buffers.rs

  1use super::*;
  2use prost::Message;
  3use text::{EditOperation, UndoOperation};
  4
  5impl Database {
  6    pub async fn join_channel_buffer(
  7        &self,
  8        channel_id: ChannelId,
  9        user_id: UserId,
 10        connection: ConnectionId,
 11    ) -> Result<proto::JoinChannelBufferResponse> {
 12        self.transaction(|tx| async move {
 13            self.check_user_is_channel_member(channel_id, user_id, &tx)
 14                .await?;
 15
 16            let buffer = channel::Model {
 17                id: channel_id,
 18                ..Default::default()
 19            }
 20            .find_related(buffer::Entity)
 21            .one(&*tx)
 22            .await?;
 23
 24            let buffer = if let Some(buffer) = buffer {
 25                buffer
 26            } else {
 27                let buffer = buffer::ActiveModel {
 28                    channel_id: ActiveValue::Set(channel_id),
 29                    ..Default::default()
 30                }
 31                .insert(&*tx)
 32                .await?;
 33                buffer_snapshot::ActiveModel {
 34                    buffer_id: ActiveValue::Set(buffer.id),
 35                    epoch: ActiveValue::Set(0),
 36                    text: ActiveValue::Set(String::new()),
 37                    operation_serialization_version: ActiveValue::Set(
 38                        storage::SERIALIZATION_VERSION,
 39                    ),
 40                }
 41                .insert(&*tx)
 42                .await?;
 43                buffer
 44            };
 45
 46            // Join the collaborators
 47            let mut collaborators = channel_buffer_collaborator::Entity::find()
 48                .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 49                .all(&*tx)
 50                .await?;
 51            let replica_ids = collaborators
 52                .iter()
 53                .map(|c| c.replica_id)
 54                .collect::<HashSet<_>>();
 55            let mut replica_id = ReplicaId(0);
 56            while replica_ids.contains(&replica_id) {
 57                replica_id.0 += 1;
 58            }
 59            let collaborator = channel_buffer_collaborator::ActiveModel {
 60                channel_id: ActiveValue::Set(channel_id),
 61                connection_id: ActiveValue::Set(connection.id as i32),
 62                connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
 63                user_id: ActiveValue::Set(user_id),
 64                replica_id: ActiveValue::Set(replica_id),
 65                ..Default::default()
 66            }
 67            .insert(&*tx)
 68            .await?;
 69            collaborators.push(collaborator);
 70
 71            let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
 72
 73            Ok(proto::JoinChannelBufferResponse {
 74                buffer_id: buffer.id.to_proto(),
 75                replica_id: replica_id.to_proto() as u32,
 76                base_text,
 77                operations,
 78                epoch: buffer.epoch as u64,
 79                collaborators: collaborators
 80                    .into_iter()
 81                    .map(|collaborator| proto::Collaborator {
 82                        peer_id: Some(collaborator.connection().into()),
 83                        user_id: collaborator.user_id.to_proto(),
 84                        replica_id: collaborator.replica_id.0 as u32,
 85                    })
 86                    .collect(),
 87            })
 88        })
 89        .await
 90    }
 91
 92    pub async fn rejoin_channel_buffers(
 93        &self,
 94        buffers: &[proto::ChannelBufferVersion],
 95        user_id: UserId,
 96        connection_id: ConnectionId,
 97    ) -> Result<proto::RejoinChannelBuffersResponse> {
 98        self.transaction(|tx| async move {
 99            let mut response = proto::RejoinChannelBuffersResponse::default();
100            for client_buffer in buffers {
101                let channel_id = ChannelId::from_proto(client_buffer.channel_id);
102                if self
103                    .check_user_is_channel_member(channel_id, user_id, &*tx)
104                    .await
105                    .is_err()
106                {
107                    log::info!("user is not a member of channel");
108                    continue;
109                }
110
111                let buffer = self.get_channel_buffer(channel_id, &*tx).await?;
112                let mut collaborators = channel_buffer_collaborator::Entity::find()
113                    .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
114                    .all(&*tx)
115                    .await?;
116
117                // If the buffer epoch hasn't changed since the client lost
118                // connection, then the client's buffer can be syncronized with
119                // the server's buffer.
120                if buffer.epoch as u64 != client_buffer.epoch {
121                    continue;
122                }
123
124                // If there is still a disconnected collaborator for the user,
125                // update the connection associated with that collaborator, and reuse
126                // that replica id.
127                if let Some(ix) = collaborators
128                    .iter()
129                    .position(|c| c.user_id == user_id && c.connection_lost)
130                {
131                    let self_collaborator = &mut collaborators[ix];
132                    *self_collaborator = channel_buffer_collaborator::ActiveModel {
133                        id: ActiveValue::Unchanged(self_collaborator.id),
134                        connection_id: ActiveValue::Set(connection_id.id as i32),
135                        connection_server_id: ActiveValue::Set(ServerId(
136                            connection_id.owner_id as i32,
137                        )),
138                        connection_lost: ActiveValue::Set(false),
139                        ..Default::default()
140                    }
141                    .update(&*tx)
142                    .await?;
143                } else {
144                    continue;
145                }
146
147                let client_version = version_from_wire(&client_buffer.version);
148                let serialization_version = self
149                    .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
150                    .await?;
151
152                let mut rows = buffer_operation::Entity::find()
153                    .filter(
154                        buffer_operation::Column::BufferId
155                            .eq(buffer.id)
156                            .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
157                    )
158                    .stream(&*tx)
159                    .await?;
160
161                // Find the server's version vector and any operations
162                // that the client has not seen.
163                let mut server_version = clock::Global::new();
164                let mut operations = Vec::new();
165                while let Some(row) = rows.next().await {
166                    let row = row?;
167                    let timestamp = clock::Lamport {
168                        replica_id: row.replica_id as u16,
169                        value: row.lamport_timestamp as u32,
170                    };
171                    server_version.observe(timestamp);
172                    if !client_version.observed(timestamp) {
173                        operations.push(proto::Operation {
174                            variant: Some(operation_from_storage(row, serialization_version)?),
175                        })
176                    }
177                }
178
179                response.buffers.push(proto::RejoinedChannelBuffer {
180                    channel_id: client_buffer.channel_id,
181                    version: version_to_wire(&server_version),
182                    operations,
183                    collaborators: collaborators
184                        .into_iter()
185                        .map(|collaborator| proto::Collaborator {
186                            peer_id: Some(collaborator.connection().into()),
187                            user_id: collaborator.user_id.to_proto(),
188                            replica_id: collaborator.replica_id.0 as u32,
189                        })
190                        .collect(),
191                });
192            }
193
194            Ok(response)
195        })
196        .await
197    }
198
199    pub async fn leave_channel_buffer(
200        &self,
201        channel_id: ChannelId,
202        connection: ConnectionId,
203    ) -> Result<Vec<ConnectionId>> {
204        self.transaction(|tx| async move {
205            self.leave_channel_buffer_internal(channel_id, connection, &*tx)
206                .await
207        })
208        .await
209    }
210
211    pub async fn leave_channel_buffers(
212        &self,
213        connection: ConnectionId,
214    ) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
215        self.transaction(|tx| async move {
216            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
217            enum QueryChannelIds {
218                ChannelId,
219            }
220
221            let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
222                .select_only()
223                .column(channel_buffer_collaborator::Column::ChannelId)
224                .filter(Condition::all().add(
225                    channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
226                ))
227                .into_values::<_, QueryChannelIds>()
228                .all(&*tx)
229                .await?;
230
231            let mut result = Vec::new();
232            for channel_id in channel_ids {
233                let collaborators = self
234                    .leave_channel_buffer_internal(channel_id, connection, &*tx)
235                    .await?;
236                result.push((channel_id, collaborators));
237            }
238
239            Ok(result)
240        })
241        .await
242    }
243
244    pub async fn leave_channel_buffer_internal(
245        &self,
246        channel_id: ChannelId,
247        connection: ConnectionId,
248        tx: &DatabaseTransaction,
249    ) -> Result<Vec<ConnectionId>> {
250        let result = channel_buffer_collaborator::Entity::delete_many()
251            .filter(
252                Condition::all()
253                    .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
254                    .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
255                    .add(
256                        channel_buffer_collaborator::Column::ConnectionServerId
257                            .eq(connection.owner_id as i32),
258                    ),
259            )
260            .exec(&*tx)
261            .await?;
262        if result.rows_affected == 0 {
263            Err(anyhow!("not a collaborator on this project"))?;
264        }
265
266        let mut connections = Vec::new();
267        let mut rows = channel_buffer_collaborator::Entity::find()
268            .filter(
269                Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
270            )
271            .stream(&*tx)
272            .await?;
273        while let Some(row) = rows.next().await {
274            let row = row?;
275            connections.push(ConnectionId {
276                id: row.connection_id as u32,
277                owner_id: row.connection_server_id.0 as u32,
278            });
279        }
280
281        drop(rows);
282
283        if connections.is_empty() {
284            self.snapshot_channel_buffer(channel_id, &tx).await?;
285        }
286
287        Ok(connections)
288    }
289
290    pub async fn get_channel_buffer_collaborators(
291        &self,
292        channel_id: ChannelId,
293    ) -> Result<Vec<UserId>> {
294        self.transaction(|tx| async move {
295            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
296            enum QueryUserIds {
297                UserId,
298            }
299
300            let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
301                .select_only()
302                .column(channel_buffer_collaborator::Column::UserId)
303                .filter(
304                    Condition::all()
305                        .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
306                )
307                .into_values::<_, QueryUserIds>()
308                .all(&*tx)
309                .await?;
310
311            Ok(users)
312        })
313        .await
314    }
315
316    pub async fn update_channel_buffer(
317        &self,
318        channel_id: ChannelId,
319        user: UserId,
320        operations: &[proto::Operation],
321    ) -> Result<Vec<ConnectionId>> {
322        self.transaction(move |tx| async move {
323            self.check_user_is_channel_member(channel_id, user, &*tx)
324                .await?;
325
326            let buffer = buffer::Entity::find()
327                .filter(buffer::Column::ChannelId.eq(channel_id))
328                .one(&*tx)
329                .await?
330                .ok_or_else(|| anyhow!("no such buffer"))?;
331
332            let serialization_version = self
333                .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
334                .await?;
335
336            let operations = operations
337                .iter()
338                .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
339                .collect::<Vec<_>>();
340            if !operations.is_empty() {
341                buffer_operation::Entity::insert_many(operations)
342                    .exec(&*tx)
343                    .await?;
344            }
345
346            let mut connections = Vec::new();
347            let mut rows = channel_buffer_collaborator::Entity::find()
348                .filter(
349                    Condition::all()
350                        .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
351                )
352                .stream(&*tx)
353                .await?;
354            while let Some(row) = rows.next().await {
355                let row = row?;
356                connections.push(ConnectionId {
357                    id: row.connection_id as u32,
358                    owner_id: row.connection_server_id.0 as u32,
359                });
360            }
361
362            Ok(connections)
363        })
364        .await
365    }
366
367    async fn get_buffer_operation_serialization_version(
368        &self,
369        buffer_id: BufferId,
370        epoch: i32,
371        tx: &DatabaseTransaction,
372    ) -> Result<i32> {
373        Ok(buffer_snapshot::Entity::find()
374            .filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
375            .filter(buffer_snapshot::Column::Epoch.eq(epoch))
376            .select_only()
377            .column(buffer_snapshot::Column::OperationSerializationVersion)
378            .into_values::<_, QueryOperationSerializationVersion>()
379            .one(&*tx)
380            .await?
381            .ok_or_else(|| anyhow!("missing buffer snapshot"))?)
382    }
383
384    async fn get_channel_buffer(
385        &self,
386        channel_id: ChannelId,
387        tx: &DatabaseTransaction,
388    ) -> Result<buffer::Model> {
389        Ok(channel::Model {
390            id: channel_id,
391            ..Default::default()
392        }
393        .find_related(buffer::Entity)
394        .one(&*tx)
395        .await?
396        .ok_or_else(|| anyhow!("no such buffer"))?)
397    }
398
399    async fn get_buffer_state(
400        &self,
401        buffer: &buffer::Model,
402        tx: &DatabaseTransaction,
403    ) -> Result<(String, Vec<proto::Operation>)> {
404        let id = buffer.id;
405        let (base_text, version) = if buffer.epoch > 0 {
406            let snapshot = buffer_snapshot::Entity::find()
407                .filter(
408                    buffer_snapshot::Column::BufferId
409                        .eq(id)
410                        .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
411                )
412                .one(&*tx)
413                .await?
414                .ok_or_else(|| anyhow!("no such snapshot"))?;
415
416            let version = snapshot.operation_serialization_version;
417            (snapshot.text, version)
418        } else {
419            (String::new(), storage::SERIALIZATION_VERSION)
420        };
421
422        let mut rows = buffer_operation::Entity::find()
423            .filter(
424                buffer_operation::Column::BufferId
425                    .eq(id)
426                    .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
427            )
428            .stream(&*tx)
429            .await?;
430        let mut operations = Vec::new();
431        while let Some(row) = rows.next().await {
432            operations.push(proto::Operation {
433                variant: Some(operation_from_storage(row?, version)?),
434            })
435        }
436
437        Ok((base_text, operations))
438    }
439
440    async fn snapshot_channel_buffer(
441        &self,
442        channel_id: ChannelId,
443        tx: &DatabaseTransaction,
444    ) -> Result<()> {
445        let buffer = self.get_channel_buffer(channel_id, tx).await?;
446        let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
447        if operations.is_empty() {
448            return Ok(());
449        }
450
451        let mut text_buffer = text::Buffer::new(0, 0, base_text);
452        text_buffer
453            .apply_ops(operations.into_iter().filter_map(operation_from_wire))
454            .unwrap();
455
456        let base_text = text_buffer.text();
457        let epoch = buffer.epoch + 1;
458
459        buffer_snapshot::Model {
460            buffer_id: buffer.id,
461            epoch,
462            text: base_text,
463            operation_serialization_version: storage::SERIALIZATION_VERSION,
464        }
465        .into_active_model()
466        .insert(tx)
467        .await?;
468
469        buffer::ActiveModel {
470            id: ActiveValue::Unchanged(buffer.id),
471            epoch: ActiveValue::Set(epoch),
472            ..Default::default()
473        }
474        .save(tx)
475        .await?;
476
477        Ok(())
478    }
479}
480
481fn operation_to_storage(
482    operation: &proto::Operation,
483    buffer: &buffer::Model,
484    _format: i32,
485) -> Option<buffer_operation::ActiveModel> {
486    let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
487        proto::operation::Variant::Edit(operation) => (
488            operation.replica_id,
489            operation.lamport_timestamp,
490            storage::Operation {
491                version: version_to_storage(&operation.version),
492                is_undo: false,
493                edit_ranges: operation
494                    .ranges
495                    .iter()
496                    .map(|range| storage::Range {
497                        start: range.start,
498                        end: range.end,
499                    })
500                    .collect(),
501                edit_texts: operation.new_text.clone(),
502                undo_counts: Vec::new(),
503            },
504        ),
505        proto::operation::Variant::Undo(operation) => (
506            operation.replica_id,
507            operation.lamport_timestamp,
508            storage::Operation {
509                version: version_to_storage(&operation.version),
510                is_undo: true,
511                edit_ranges: Vec::new(),
512                edit_texts: Vec::new(),
513                undo_counts: operation
514                    .counts
515                    .iter()
516                    .map(|entry| storage::UndoCount {
517                        replica_id: entry.replica_id,
518                        lamport_timestamp: entry.lamport_timestamp,
519                        count: entry.count,
520                    })
521                    .collect(),
522            },
523        ),
524        _ => None?,
525    };
526
527    Some(buffer_operation::ActiveModel {
528        buffer_id: ActiveValue::Set(buffer.id),
529        epoch: ActiveValue::Set(buffer.epoch),
530        replica_id: ActiveValue::Set(replica_id as i32),
531        lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
532        value: ActiveValue::Set(value.encode_to_vec()),
533    })
534}
535
536fn operation_from_storage(
537    row: buffer_operation::Model,
538    _format_version: i32,
539) -> Result<proto::operation::Variant, Error> {
540    let operation =
541        storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
542    let version = version_from_storage(&operation.version);
543    Ok(if operation.is_undo {
544        proto::operation::Variant::Undo(proto::operation::Undo {
545            replica_id: row.replica_id as u32,
546            lamport_timestamp: row.lamport_timestamp as u32,
547            version,
548            counts: operation
549                .undo_counts
550                .iter()
551                .map(|entry| proto::UndoCount {
552                    replica_id: entry.replica_id,
553                    lamport_timestamp: entry.lamport_timestamp,
554                    count: entry.count,
555                })
556                .collect(),
557        })
558    } else {
559        proto::operation::Variant::Edit(proto::operation::Edit {
560            replica_id: row.replica_id as u32,
561            lamport_timestamp: row.lamport_timestamp as u32,
562            version,
563            ranges: operation
564                .edit_ranges
565                .into_iter()
566                .map(|range| proto::Range {
567                    start: range.start,
568                    end: range.end,
569                })
570                .collect(),
571            new_text: operation.edit_texts,
572        })
573    })
574}
575
576fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
577    version
578        .iter()
579        .map(|entry| storage::VectorClockEntry {
580            replica_id: entry.replica_id,
581            timestamp: entry.timestamp,
582        })
583        .collect()
584}
585
586fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
587    version
588        .iter()
589        .map(|entry| proto::VectorClockEntry {
590            replica_id: entry.replica_id,
591            timestamp: entry.timestamp,
592        })
593        .collect()
594}
595
596// This is currently a manual copy of the deserialization code in the client's langauge crate
597pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
598    match operation.variant? {
599        proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
600            timestamp: clock::Lamport {
601                replica_id: edit.replica_id as text::ReplicaId,
602                value: edit.lamport_timestamp,
603            },
604            version: version_from_wire(&edit.version),
605            ranges: edit
606                .ranges
607                .into_iter()
608                .map(|range| {
609                    text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
610                })
611                .collect(),
612            new_text: edit.new_text.into_iter().map(Arc::from).collect(),
613        })),
614        proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
615            timestamp: clock::Lamport {
616                replica_id: undo.replica_id as text::ReplicaId,
617                value: undo.lamport_timestamp,
618            },
619            version: version_from_wire(&undo.version),
620            counts: undo
621                .counts
622                .into_iter()
623                .map(|c| {
624                    (
625                        clock::Lamport {
626                            replica_id: c.replica_id as text::ReplicaId,
627                            value: c.lamport_timestamp,
628                        },
629                        c.count,
630                    )
631                })
632                .collect(),
633        })),
634        _ => None,
635    }
636}
637
638fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
639    let mut version = clock::Global::new();
640    for entry in message {
641        version.observe(clock::Lamport {
642            replica_id: entry.replica_id as text::ReplicaId,
643            value: entry.timestamp,
644        });
645    }
646    version
647}
648
649fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
650    let mut message = Vec::new();
651    for entry in version.iter() {
652        message.push(proto::VectorClockEntry {
653            replica_id: entry.replica_id as u32,
654            timestamp: entry.value,
655        });
656    }
657    message
658}
659
/// Column selector passed to `into_values` so that only the
/// `operation_serialization_version` column of a `buffer_snapshot` row is read
/// (used by `get_buffer_operation_serialization_version`).
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryOperationSerializationVersion {
    OperationSerializationVersion,
}
664
// Protobuf messages used to encode the `value` column of `buffer_operation`
// rows (written by `operation_to_storage`, read by `operation_from_storage`).
// NOTE(review): rows already in the database were encoded with these tags and
// field types, so changing them would presumably break decoding of existing
// data — confirm before editing.
mod storage {
    // NOTE(review): the reason for this allow is not visible here; it may be
    // needed by the prost derive output — confirm before removing.
    #![allow(non_snake_case)]
    use prost::Message;
    // Version recorded in `buffer_snapshot.operation_serialization_version`
    // for snapshots created by this code.
    pub const SERIALIZATION_VERSION: i32 = 1;

    // A stored edit or undo. Tag 1 is not used by this message.
    #[derive(Message)]
    pub struct Operation {
        // Version vector carried by the operation.
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        // `true` for undos (only `undo_counts` is populated); `false` for edits
        // (only `edit_ranges` / `edit_texts` are populated).
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    // One (replica id, timestamp) entry of a version vector; mirrors
    // `proto::VectorClockEntry`.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    // An edited range; mirrors `proto::Range`.
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    // Undo count for the operation identified by (`replica_id`,
    // `lamport_timestamp`); mirrors `proto::UndoCount`.
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub lamport_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}