buffers.rs

  1use super::*;
  2use prost::Message;
  3use text::{EditOperation, UndoOperation};
  4
  5impl Database {
  6    pub async fn join_channel_buffer(
  7        &self,
  8        channel_id: ChannelId,
  9        user_id: UserId,
 10        connection: ConnectionId,
 11    ) -> Result<proto::JoinChannelBufferResponse> {
 12        self.transaction(|tx| async move {
 13            self.check_user_is_channel_member(channel_id, user_id, &tx)
 14                .await?;
 15
 16            let buffer = channel::Model {
 17                id: channel_id,
 18                ..Default::default()
 19            }
 20            .find_related(buffer::Entity)
 21            .one(&*tx)
 22            .await?;
 23
 24            let buffer = if let Some(buffer) = buffer {
 25                buffer
 26            } else {
 27                let buffer = buffer::ActiveModel {
 28                    channel_id: ActiveValue::Set(channel_id),
 29                    ..Default::default()
 30                }
 31                .insert(&*tx)
 32                .await?;
 33                buffer_snapshot::ActiveModel {
 34                    buffer_id: ActiveValue::Set(buffer.id),
 35                    epoch: ActiveValue::Set(0),
 36                    text: ActiveValue::Set(String::new()),
 37                    operation_serialization_version: ActiveValue::Set(
 38                        storage::SERIALIZATION_VERSION,
 39                    ),
 40                }
 41                .insert(&*tx)
 42                .await?;
 43                buffer
 44            };
 45
 46            // Join the collaborators
 47            let mut collaborators = channel_buffer_collaborator::Entity::find()
 48                .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 49                .all(&*tx)
 50                .await?;
 51            let replica_ids = collaborators
 52                .iter()
 53                .map(|c| c.replica_id)
 54                .collect::<HashSet<_>>();
 55            let mut replica_id = ReplicaId(0);
 56            while replica_ids.contains(&replica_id) {
 57                replica_id.0 += 1;
 58            }
 59            let collaborator = channel_buffer_collaborator::ActiveModel {
 60                channel_id: ActiveValue::Set(channel_id),
 61                connection_id: ActiveValue::Set(connection.id as i32),
 62                connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
 63                user_id: ActiveValue::Set(user_id),
 64                replica_id: ActiveValue::Set(replica_id),
 65                ..Default::default()
 66            }
 67            .insert(&*tx)
 68            .await?;
 69            collaborators.push(collaborator);
 70
 71            let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
 72
 73            Ok(proto::JoinChannelBufferResponse {
 74                buffer_id: buffer.id.to_proto(),
 75                replica_id: replica_id.to_proto() as u32,
 76                base_text,
 77                operations,
 78                epoch: buffer.epoch as u64,
 79                collaborators: collaborators
 80                    .into_iter()
 81                    .map(|collaborator| proto::Collaborator {
 82                        peer_id: Some(collaborator.connection().into()),
 83                        user_id: collaborator.user_id.to_proto(),
 84                        replica_id: collaborator.replica_id.0 as u32,
 85                    })
 86                    .collect(),
 87            })
 88        })
 89        .await
 90    }
 91
 92    pub async fn rejoin_channel_buffers(
 93        &self,
 94        buffers: &[proto::ChannelBufferVersion],
 95        user_id: UserId,
 96        connection_id: ConnectionId,
 97    ) -> Result<Vec<RejoinedChannelBuffer>> {
 98        self.transaction(|tx| async move {
 99            let mut results = Vec::new();
100            for client_buffer in buffers {
101                let channel_id = ChannelId::from_proto(client_buffer.channel_id);
102                if self
103                    .check_user_is_channel_member(channel_id, user_id, &*tx)
104                    .await
105                    .is_err()
106                {
107                    log::info!("user is not a member of channel");
108                    continue;
109                }
110
111                let buffer = self.get_channel_buffer(channel_id, &*tx).await?;
112                let mut collaborators = channel_buffer_collaborator::Entity::find()
113                    .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
114                    .all(&*tx)
115                    .await?;
116
117                // If the buffer epoch hasn't changed since the client lost
118                // connection, then the client's buffer can be syncronized with
119                // the server's buffer.
120                if buffer.epoch as u64 != client_buffer.epoch {
121                    continue;
122                }
123
124                // Find the collaborator record for this user's previous lost
125                // connection. Update it with the new connection id.
126                let server_id = ServerId(connection_id.owner_id as i32);
127                let Some(self_collaborator) = collaborators.iter_mut().find(|c| {
128                    c.user_id == user_id
129                        && (c.connection_lost || c.connection_server_id != server_id)
130                }) else {
131                    continue;
132                };
133                let old_connection_id = self_collaborator.connection();
134                *self_collaborator = channel_buffer_collaborator::ActiveModel {
135                    id: ActiveValue::Unchanged(self_collaborator.id),
136                    connection_id: ActiveValue::Set(connection_id.id as i32),
137                    connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
138                    connection_lost: ActiveValue::Set(false),
139                    ..Default::default()
140                }
141                .update(&*tx)
142                .await?;
143
144                let client_version = version_from_wire(&client_buffer.version);
145                let serialization_version = self
146                    .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
147                    .await?;
148
149                let mut rows = buffer_operation::Entity::find()
150                    .filter(
151                        buffer_operation::Column::BufferId
152                            .eq(buffer.id)
153                            .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
154                    )
155                    .stream(&*tx)
156                    .await?;
157
158                // Find the server's version vector and any operations
159                // that the client has not seen.
160                let mut server_version = clock::Global::new();
161                let mut operations = Vec::new();
162                while let Some(row) = rows.next().await {
163                    let row = row?;
164                    let timestamp = clock::Lamport {
165                        replica_id: row.replica_id as u16,
166                        value: row.lamport_timestamp as u32,
167                    };
168                    server_version.observe(timestamp);
169                    if !client_version.observed(timestamp) {
170                        operations.push(proto::Operation {
171                            variant: Some(operation_from_storage(row, serialization_version)?),
172                        })
173                    }
174                }
175
176                results.push(RejoinedChannelBuffer {
177                    old_connection_id,
178                    buffer: proto::RejoinedChannelBuffer {
179                        channel_id: client_buffer.channel_id,
180                        version: version_to_wire(&server_version),
181                        operations,
182                        collaborators: collaborators
183                            .into_iter()
184                            .map(|collaborator| proto::Collaborator {
185                                peer_id: Some(collaborator.connection().into()),
186                                user_id: collaborator.user_id.to_proto(),
187                                replica_id: collaborator.replica_id.0 as u32,
188                            })
189                            .collect(),
190                    },
191                });
192            }
193
194            Ok(results)
195        })
196        .await
197    }
198
199    pub async fn refresh_channel_buffer(
200        &self,
201        channel_id: ChannelId,
202        server_id: ServerId,
203    ) -> Result<RefreshedChannelBuffer> {
204        self.transaction(|tx| async move {
205            let mut connection_ids = Vec::new();
206            let mut removed_collaborators = Vec::new();
207
208            // TODO
209
210            Ok(RefreshedChannelBuffer {
211                connection_ids,
212                removed_collaborators,
213            })
214        })
215        .await
216    }
217
218    pub async fn leave_channel_buffer(
219        &self,
220        channel_id: ChannelId,
221        connection: ConnectionId,
222    ) -> Result<Vec<ConnectionId>> {
223        self.transaction(|tx| async move {
224            self.leave_channel_buffer_internal(channel_id, connection, &*tx)
225                .await
226        })
227        .await
228    }
229
230    pub async fn leave_channel_buffers(
231        &self,
232        connection: ConnectionId,
233    ) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
234        self.transaction(|tx| async move {
235            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
236            enum QueryChannelIds {
237                ChannelId,
238            }
239
240            let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
241                .select_only()
242                .column(channel_buffer_collaborator::Column::ChannelId)
243                .filter(Condition::all().add(
244                    channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
245                ))
246                .into_values::<_, QueryChannelIds>()
247                .all(&*tx)
248                .await?;
249
250            let mut result = Vec::new();
251            for channel_id in channel_ids {
252                let collaborators = self
253                    .leave_channel_buffer_internal(channel_id, connection, &*tx)
254                    .await?;
255                result.push((channel_id, collaborators));
256            }
257
258            Ok(result)
259        })
260        .await
261    }
262
263    pub async fn leave_channel_buffer_internal(
264        &self,
265        channel_id: ChannelId,
266        connection: ConnectionId,
267        tx: &DatabaseTransaction,
268    ) -> Result<Vec<ConnectionId>> {
269        let result = channel_buffer_collaborator::Entity::delete_many()
270            .filter(
271                Condition::all()
272                    .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
273                    .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
274                    .add(
275                        channel_buffer_collaborator::Column::ConnectionServerId
276                            .eq(connection.owner_id as i32),
277                    ),
278            )
279            .exec(&*tx)
280            .await?;
281        if result.rows_affected == 0 {
282            Err(anyhow!("not a collaborator on this project"))?;
283        }
284
285        let mut connections = Vec::new();
286        let mut rows = channel_buffer_collaborator::Entity::find()
287            .filter(
288                Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
289            )
290            .stream(&*tx)
291            .await?;
292        while let Some(row) = rows.next().await {
293            let row = row?;
294            connections.push(ConnectionId {
295                id: row.connection_id as u32,
296                owner_id: row.connection_server_id.0 as u32,
297            });
298        }
299
300        drop(rows);
301
302        if connections.is_empty() {
303            self.snapshot_channel_buffer(channel_id, &tx).await?;
304        }
305
306        Ok(connections)
307    }
308
309    pub async fn get_channel_buffer_collaborators(
310        &self,
311        channel_id: ChannelId,
312    ) -> Result<Vec<UserId>> {
313        self.transaction(|tx| async move {
314            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
315            enum QueryUserIds {
316                UserId,
317            }
318
319            let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
320                .select_only()
321                .column(channel_buffer_collaborator::Column::UserId)
322                .filter(
323                    Condition::all()
324                        .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
325                )
326                .into_values::<_, QueryUserIds>()
327                .all(&*tx)
328                .await?;
329
330            Ok(users)
331        })
332        .await
333    }
334
335    pub async fn update_channel_buffer(
336        &self,
337        channel_id: ChannelId,
338        user: UserId,
339        operations: &[proto::Operation],
340    ) -> Result<Vec<ConnectionId>> {
341        self.transaction(move |tx| async move {
342            self.check_user_is_channel_member(channel_id, user, &*tx)
343                .await?;
344
345            let buffer = buffer::Entity::find()
346                .filter(buffer::Column::ChannelId.eq(channel_id))
347                .one(&*tx)
348                .await?
349                .ok_or_else(|| anyhow!("no such buffer"))?;
350
351            let serialization_version = self
352                .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
353                .await?;
354
355            let operations = operations
356                .iter()
357                .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
358                .collect::<Vec<_>>();
359            if !operations.is_empty() {
360                buffer_operation::Entity::insert_many(operations)
361                    .exec(&*tx)
362                    .await?;
363            }
364
365            let mut connections = Vec::new();
366            let mut rows = channel_buffer_collaborator::Entity::find()
367                .filter(
368                    Condition::all()
369                        .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
370                )
371                .stream(&*tx)
372                .await?;
373            while let Some(row) = rows.next().await {
374                let row = row?;
375                connections.push(ConnectionId {
376                    id: row.connection_id as u32,
377                    owner_id: row.connection_server_id.0 as u32,
378                });
379            }
380
381            Ok(connections)
382        })
383        .await
384    }
385
386    async fn get_buffer_operation_serialization_version(
387        &self,
388        buffer_id: BufferId,
389        epoch: i32,
390        tx: &DatabaseTransaction,
391    ) -> Result<i32> {
392        Ok(buffer_snapshot::Entity::find()
393            .filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
394            .filter(buffer_snapshot::Column::Epoch.eq(epoch))
395            .select_only()
396            .column(buffer_snapshot::Column::OperationSerializationVersion)
397            .into_values::<_, QueryOperationSerializationVersion>()
398            .one(&*tx)
399            .await?
400            .ok_or_else(|| anyhow!("missing buffer snapshot"))?)
401    }
402
403    async fn get_channel_buffer(
404        &self,
405        channel_id: ChannelId,
406        tx: &DatabaseTransaction,
407    ) -> Result<buffer::Model> {
408        Ok(channel::Model {
409            id: channel_id,
410            ..Default::default()
411        }
412        .find_related(buffer::Entity)
413        .one(&*tx)
414        .await?
415        .ok_or_else(|| anyhow!("no such buffer"))?)
416    }
417
418    async fn get_buffer_state(
419        &self,
420        buffer: &buffer::Model,
421        tx: &DatabaseTransaction,
422    ) -> Result<(String, Vec<proto::Operation>)> {
423        let id = buffer.id;
424        let (base_text, version) = if buffer.epoch > 0 {
425            let snapshot = buffer_snapshot::Entity::find()
426                .filter(
427                    buffer_snapshot::Column::BufferId
428                        .eq(id)
429                        .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
430                )
431                .one(&*tx)
432                .await?
433                .ok_or_else(|| anyhow!("no such snapshot"))?;
434
435            let version = snapshot.operation_serialization_version;
436            (snapshot.text, version)
437        } else {
438            (String::new(), storage::SERIALIZATION_VERSION)
439        };
440
441        let mut rows = buffer_operation::Entity::find()
442            .filter(
443                buffer_operation::Column::BufferId
444                    .eq(id)
445                    .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
446            )
447            .stream(&*tx)
448            .await?;
449        let mut operations = Vec::new();
450        while let Some(row) = rows.next().await {
451            operations.push(proto::Operation {
452                variant: Some(operation_from_storage(row?, version)?),
453            })
454        }
455
456        Ok((base_text, operations))
457    }
458
459    async fn snapshot_channel_buffer(
460        &self,
461        channel_id: ChannelId,
462        tx: &DatabaseTransaction,
463    ) -> Result<()> {
464        let buffer = self.get_channel_buffer(channel_id, tx).await?;
465        let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
466        if operations.is_empty() {
467            return Ok(());
468        }
469
470        let mut text_buffer = text::Buffer::new(0, 0, base_text);
471        text_buffer
472            .apply_ops(operations.into_iter().filter_map(operation_from_wire))
473            .unwrap();
474
475        let base_text = text_buffer.text();
476        let epoch = buffer.epoch + 1;
477
478        buffer_snapshot::Model {
479            buffer_id: buffer.id,
480            epoch,
481            text: base_text,
482            operation_serialization_version: storage::SERIALIZATION_VERSION,
483        }
484        .into_active_model()
485        .insert(tx)
486        .await?;
487
488        buffer::ActiveModel {
489            id: ActiveValue::Unchanged(buffer.id),
490            epoch: ActiveValue::Set(epoch),
491            ..Default::default()
492        }
493        .save(tx)
494        .await?;
495
496        Ok(())
497    }
498}
499
500fn operation_to_storage(
501    operation: &proto::Operation,
502    buffer: &buffer::Model,
503    _format: i32,
504) -> Option<buffer_operation::ActiveModel> {
505    let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
506        proto::operation::Variant::Edit(operation) => (
507            operation.replica_id,
508            operation.lamport_timestamp,
509            storage::Operation {
510                version: version_to_storage(&operation.version),
511                is_undo: false,
512                edit_ranges: operation
513                    .ranges
514                    .iter()
515                    .map(|range| storage::Range {
516                        start: range.start,
517                        end: range.end,
518                    })
519                    .collect(),
520                edit_texts: operation.new_text.clone(),
521                undo_counts: Vec::new(),
522            },
523        ),
524        proto::operation::Variant::Undo(operation) => (
525            operation.replica_id,
526            operation.lamport_timestamp,
527            storage::Operation {
528                version: version_to_storage(&operation.version),
529                is_undo: true,
530                edit_ranges: Vec::new(),
531                edit_texts: Vec::new(),
532                undo_counts: operation
533                    .counts
534                    .iter()
535                    .map(|entry| storage::UndoCount {
536                        replica_id: entry.replica_id,
537                        lamport_timestamp: entry.lamport_timestamp,
538                        count: entry.count,
539                    })
540                    .collect(),
541            },
542        ),
543        _ => None?,
544    };
545
546    Some(buffer_operation::ActiveModel {
547        buffer_id: ActiveValue::Set(buffer.id),
548        epoch: ActiveValue::Set(buffer.epoch),
549        replica_id: ActiveValue::Set(replica_id as i32),
550        lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
551        value: ActiveValue::Set(value.encode_to_vec()),
552    })
553}
554
555fn operation_from_storage(
556    row: buffer_operation::Model,
557    _format_version: i32,
558) -> Result<proto::operation::Variant, Error> {
559    let operation =
560        storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
561    let version = version_from_storage(&operation.version);
562    Ok(if operation.is_undo {
563        proto::operation::Variant::Undo(proto::operation::Undo {
564            replica_id: row.replica_id as u32,
565            lamport_timestamp: row.lamport_timestamp as u32,
566            version,
567            counts: operation
568                .undo_counts
569                .iter()
570                .map(|entry| proto::UndoCount {
571                    replica_id: entry.replica_id,
572                    lamport_timestamp: entry.lamport_timestamp,
573                    count: entry.count,
574                })
575                .collect(),
576        })
577    } else {
578        proto::operation::Variant::Edit(proto::operation::Edit {
579            replica_id: row.replica_id as u32,
580            lamport_timestamp: row.lamport_timestamp as u32,
581            version,
582            ranges: operation
583                .edit_ranges
584                .into_iter()
585                .map(|range| proto::Range {
586                    start: range.start,
587                    end: range.end,
588                })
589                .collect(),
590            new_text: operation.edit_texts,
591        })
592    })
593}
594
595fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
596    version
597        .iter()
598        .map(|entry| storage::VectorClockEntry {
599            replica_id: entry.replica_id,
600            timestamp: entry.timestamp,
601        })
602        .collect()
603}
604
605fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
606    version
607        .iter()
608        .map(|entry| proto::VectorClockEntry {
609            replica_id: entry.replica_id,
610            timestamp: entry.timestamp,
611        })
612        .collect()
613}
614
615// This is currently a manual copy of the deserialization code in the client's langauge crate
616pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
617    match operation.variant? {
618        proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
619            timestamp: clock::Lamport {
620                replica_id: edit.replica_id as text::ReplicaId,
621                value: edit.lamport_timestamp,
622            },
623            version: version_from_wire(&edit.version),
624            ranges: edit
625                .ranges
626                .into_iter()
627                .map(|range| {
628                    text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
629                })
630                .collect(),
631            new_text: edit.new_text.into_iter().map(Arc::from).collect(),
632        })),
633        proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
634            timestamp: clock::Lamport {
635                replica_id: undo.replica_id as text::ReplicaId,
636                value: undo.lamport_timestamp,
637            },
638            version: version_from_wire(&undo.version),
639            counts: undo
640                .counts
641                .into_iter()
642                .map(|c| {
643                    (
644                        clock::Lamport {
645                            replica_id: c.replica_id as text::ReplicaId,
646                            value: c.lamport_timestamp,
647                        },
648                        c.count,
649                    )
650                })
651                .collect(),
652        })),
653        _ => None,
654    }
655}
656
657fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
658    let mut version = clock::Global::new();
659    for entry in message {
660        version.observe(clock::Lamport {
661            replica_id: entry.replica_id as text::ReplicaId,
662            value: entry.timestamp,
663        });
664    }
665    version
666}
667
668fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
669    let mut message = Vec::new();
670    for entry in version.iter() {
671        message.push(proto::VectorClockEntry {
672            replica_id: entry.replica_id as u32,
673            timestamp: entry.value,
674        });
675    }
676    message
677}
678
/// Single-column selector passed to `into_values` by
/// `Database::get_buffer_operation_serialization_version`, which selects only
/// `buffer_snapshot::Column::OperationSerializationVersion`.
///
/// NOTE(review): `DeriveColumn` appears to map the variant name to the selected
/// column's name, so the variant presumably must not be renamed — confirm.
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryOperationSerializationVersion {
    OperationSerializationVersion,
}
683
/// Protobuf messages used to persist buffer operations.
///
/// `operation_to_storage` encodes a [`storage::Operation`] into the
/// `buffer_operation.value` column; `operation_from_storage` decodes it. The
/// `tag` numbers therefore define the persisted format: never renumber or reuse
/// them.
mod storage {
    // NOTE(review): no non-snake-case names are visible in this module; this
    // allow may be a leftover — confirm before removing.
    #![allow(non_snake_case)]
    use prost::Message;
    /// Version written to `operation_serialization_version` on every snapshot this
    /// server creates.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// A stored text operation: an edit when `is_undo` is false (using
    /// `edit_ranges`/`edit_texts`), an undo otherwise (using `undo_counts`).
    ///
    /// NOTE(review): tag 1 is unused — presumably reserved for a removed field;
    /// confirm before assigning it.
    #[derive(Message)]
    pub struct Operation {
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// One entry of a stored version vector.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// A stored edit range; start/end are copied from the wire `proto::Range`.
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// A stored undo count for the operation identified by
    /// (`replica_id`, `lamport_timestamp`).
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub lamport_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}