buffers.rs

  1use super::*;
  2use prost::Message;
  3use text::{EditOperation, InsertionTimestamp, UndoOperation};
  4
  5impl Database {
  6    pub async fn join_channel_buffer(
  7        &self,
  8        channel_id: ChannelId,
  9        user_id: UserId,
 10        connection: ConnectionId,
 11    ) -> Result<proto::JoinChannelBufferResponse> {
 12        self.transaction(|tx| async move {
 13            let tx = tx;
 14
 15            self.check_user_is_channel_member(channel_id, user_id, &tx)
 16                .await?;
 17
 18            let buffer = channel::Model {
 19                id: channel_id,
 20                ..Default::default()
 21            }
 22            .find_related(buffer::Entity)
 23            .one(&*tx)
 24            .await?;
 25
 26            let buffer = if let Some(buffer) = buffer {
 27                buffer
 28            } else {
 29                let buffer = buffer::ActiveModel {
 30                    channel_id: ActiveValue::Set(channel_id),
 31                    ..Default::default()
 32                }
 33                .insert(&*tx)
 34                .await?;
 35                buffer_snapshot::ActiveModel {
 36                    buffer_id: ActiveValue::Set(buffer.id),
 37                    epoch: ActiveValue::Set(0),
 38                    text: ActiveValue::Set(String::new()),
 39                    operation_serialization_version: ActiveValue::Set(
 40                        storage::SERIALIZATION_VERSION,
 41                    ),
 42                }
 43                .insert(&*tx)
 44                .await?;
 45                buffer
 46            };
 47
 48            // Join the collaborators
 49            let mut collaborators = channel_buffer_collaborator::Entity::find()
 50                .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 51                .all(&*tx)
 52                .await?;
 53            let replica_ids = collaborators
 54                .iter()
 55                .map(|c| c.replica_id)
 56                .collect::<HashSet<_>>();
 57            let mut replica_id = ReplicaId(0);
 58            while replica_ids.contains(&replica_id) {
 59                replica_id.0 += 1;
 60            }
 61            let collaborator = channel_buffer_collaborator::ActiveModel {
 62                channel_id: ActiveValue::Set(channel_id),
 63                connection_id: ActiveValue::Set(connection.id as i32),
 64                connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
 65                user_id: ActiveValue::Set(user_id),
 66                replica_id: ActiveValue::Set(replica_id),
 67                ..Default::default()
 68            }
 69            .insert(&*tx)
 70            .await?;
 71            collaborators.push(collaborator);
 72
 73            // Assemble the buffer state
 74            let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
 75
 76            Ok(proto::JoinChannelBufferResponse {
 77                buffer_id: buffer.id.to_proto(),
 78                replica_id: replica_id.to_proto() as u32,
 79                base_text,
 80                operations,
 81                collaborators: collaborators
 82                    .into_iter()
 83                    .map(|collaborator| proto::Collaborator {
 84                        peer_id: Some(collaborator.connection().into()),
 85                        user_id: collaborator.user_id.to_proto(),
 86                        replica_id: collaborator.replica_id.0 as u32,
 87                    })
 88                    .collect(),
 89            })
 90        })
 91        .await
 92    }
 93
 94    pub async fn leave_channel_buffer(
 95        &self,
 96        channel_id: ChannelId,
 97        connection: ConnectionId,
 98    ) -> Result<Vec<ConnectionId>> {
 99        self.transaction(|tx| async move {
100            self.leave_channel_buffer_internal(channel_id, connection, &*tx)
101                .await
102        })
103        .await
104    }
105
106    pub async fn leave_channel_buffer_internal(
107        &self,
108        channel_id: ChannelId,
109        connection: ConnectionId,
110        tx: &DatabaseTransaction,
111    ) -> Result<Vec<ConnectionId>> {
112        let result = channel_buffer_collaborator::Entity::delete_many()
113            .filter(
114                Condition::all()
115                    .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
116                    .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
117                    .add(
118                        channel_buffer_collaborator::Column::ConnectionServerId
119                            .eq(connection.owner_id as i32),
120                    ),
121            )
122            .exec(&*tx)
123            .await?;
124        if result.rows_affected == 0 {
125            Err(anyhow!("not a collaborator on this project"))?;
126        }
127
128        let mut connections = Vec::new();
129        let mut rows = channel_buffer_collaborator::Entity::find()
130            .filter(
131                Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
132            )
133            .stream(&*tx)
134            .await?;
135        while let Some(row) = rows.next().await {
136            let row = row?;
137            connections.push(ConnectionId {
138                id: row.connection_id as u32,
139                owner_id: row.connection_server_id.0 as u32,
140            });
141        }
142
143        drop(rows);
144
145        if connections.is_empty() {
146            self.snapshot_buffer(channel_id, &tx).await?;
147        }
148
149        Ok(connections)
150    }
151
152    pub async fn leave_channel_buffers(
153        &self,
154        connection: ConnectionId,
155    ) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
156        self.transaction(|tx| async move {
157            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
158            enum QueryChannelIds {
159                ChannelId,
160            }
161
162            let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
163                .select_only()
164                .column(channel_buffer_collaborator::Column::ChannelId)
165                .filter(Condition::all().add(
166                    channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
167                ))
168                .into_values::<_, QueryChannelIds>()
169                .all(&*tx)
170                .await?;
171
172            let mut result = Vec::new();
173            for channel_id in channel_ids {
174                let collaborators = self
175                    .leave_channel_buffer_internal(channel_id, connection, &*tx)
176                    .await?;
177                result.push((channel_id, collaborators));
178            }
179
180            Ok(result)
181        })
182        .await
183    }
184
185    #[cfg(debug_assertions)]
186    pub async fn get_channel_buffer_collaborators(
187        &self,
188        channel_id: ChannelId,
189    ) -> Result<Vec<UserId>> {
190        self.transaction(|tx| async move {
191            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
192            enum QueryUserIds {
193                UserId,
194            }
195
196            let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
197                .select_only()
198                .column(channel_buffer_collaborator::Column::UserId)
199                .filter(
200                    Condition::all()
201                        .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
202                )
203                .into_values::<_, QueryUserIds>()
204                .all(&*tx)
205                .await?;
206
207            Ok(users)
208        })
209        .await
210    }
211
212    pub async fn update_channel_buffer(
213        &self,
214        channel_id: ChannelId,
215        user: UserId,
216        operations: &[proto::Operation],
217    ) -> Result<Vec<ConnectionId>> {
218        self.transaction(move |tx| async move {
219            self.check_user_is_channel_member(channel_id, user, &*tx)
220                .await?;
221
222            let buffer = buffer::Entity::find()
223                .filter(buffer::Column::ChannelId.eq(channel_id))
224                .one(&*tx)
225                .await?
226                .ok_or_else(|| anyhow!("no such buffer"))?;
227
228            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
229            enum QueryVersion {
230                OperationSerializationVersion,
231            }
232
233            let serialization_version: i32 = buffer
234                .find_related(buffer_snapshot::Entity)
235                .select_only()
236                .column(buffer_snapshot::Column::OperationSerializationVersion)
237                .filter(buffer_snapshot::Column::Epoch.eq(buffer.epoch))
238                .into_values::<_, QueryVersion>()
239                .one(&*tx)
240                .await?
241                .ok_or_else(|| anyhow!("missing buffer snapshot"))?;
242
243            let operations = operations
244                .iter()
245                .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
246                .collect::<Vec<_>>();
247            if !operations.is_empty() {
248                buffer_operation::Entity::insert_many(operations)
249                    .exec(&*tx)
250                    .await?;
251            }
252
253            let mut connections = Vec::new();
254            let mut rows = channel_buffer_collaborator::Entity::find()
255                .filter(
256                    Condition::all()
257                        .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
258                )
259                .stream(&*tx)
260                .await?;
261            while let Some(row) = rows.next().await {
262                let row = row?;
263                connections.push(ConnectionId {
264                    id: row.connection_id as u32,
265                    owner_id: row.connection_server_id.0 as u32,
266                });
267            }
268
269            Ok(connections)
270        })
271        .await
272    }
273
274    async fn get_buffer_state(
275        &self,
276        buffer: &buffer::Model,
277        tx: &DatabaseTransaction,
278    ) -> Result<(String, Vec<proto::Operation>)> {
279        let id = buffer.id;
280        let (base_text, version) = if buffer.epoch > 0 {
281            let snapshot = buffer_snapshot::Entity::find()
282                .filter(
283                    buffer_snapshot::Column::BufferId
284                        .eq(id)
285                        .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
286                )
287                .one(&*tx)
288                .await?
289                .ok_or_else(|| anyhow!("no such snapshot"))?;
290
291            let version = snapshot.operation_serialization_version;
292            (snapshot.text, version)
293        } else {
294            (String::new(), storage::SERIALIZATION_VERSION)
295        };
296
297        let mut rows = buffer_operation::Entity::find()
298            .filter(
299                buffer_operation::Column::BufferId
300                    .eq(id)
301                    .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
302            )
303            .stream(&*tx)
304            .await?;
305        let mut operations = Vec::new();
306        while let Some(row) = rows.next().await {
307            let row = row?;
308
309            let operation = operation_from_storage(row, version)?;
310            operations.push(proto::Operation {
311                variant: Some(operation),
312            })
313        }
314
315        Ok((base_text, operations))
316    }
317
318    async fn snapshot_buffer(&self, channel_id: ChannelId, tx: &DatabaseTransaction) -> Result<()> {
319        let buffer = channel::Model {
320            id: channel_id,
321            ..Default::default()
322        }
323        .find_related(buffer::Entity)
324        .one(&*tx)
325        .await?
326        .ok_or_else(|| anyhow!("no such buffer"))?;
327
328        let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
329        if operations.is_empty() {
330            return Ok(());
331        }
332
333        let mut text_buffer = text::Buffer::new(0, 0, base_text);
334        text_buffer
335            .apply_ops(operations.into_iter().filter_map(operation_from_wire))
336            .unwrap();
337
338        let base_text = text_buffer.text();
339        let epoch = buffer.epoch + 1;
340
341        buffer_snapshot::Model {
342            buffer_id: buffer.id,
343            epoch,
344            text: base_text,
345            operation_serialization_version: storage::SERIALIZATION_VERSION,
346        }
347        .into_active_model()
348        .insert(tx)
349        .await?;
350
351        buffer::ActiveModel {
352            id: ActiveValue::Unchanged(buffer.id),
353            epoch: ActiveValue::Set(epoch),
354            ..Default::default()
355        }
356        .save(tx)
357        .await?;
358
359        Ok(())
360    }
361}
362
363fn operation_to_storage(
364    operation: &proto::Operation,
365    buffer: &buffer::Model,
366    _format: i32,
367) -> Option<buffer_operation::ActiveModel> {
368    let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
369        proto::operation::Variant::Edit(operation) => (
370            operation.replica_id,
371            operation.lamport_timestamp,
372            storage::Operation {
373                local_timestamp: operation.local_timestamp,
374                version: version_to_storage(&operation.version),
375                is_undo: false,
376                edit_ranges: operation
377                    .ranges
378                    .iter()
379                    .map(|range| storage::Range {
380                        start: range.start,
381                        end: range.end,
382                    })
383                    .collect(),
384                edit_texts: operation.new_text.clone(),
385                undo_counts: Vec::new(),
386            },
387        ),
388        proto::operation::Variant::Undo(operation) => (
389            operation.replica_id,
390            operation.lamport_timestamp,
391            storage::Operation {
392                local_timestamp: operation.local_timestamp,
393                version: version_to_storage(&operation.version),
394                is_undo: true,
395                edit_ranges: Vec::new(),
396                edit_texts: Vec::new(),
397                undo_counts: operation
398                    .counts
399                    .iter()
400                    .map(|entry| storage::UndoCount {
401                        replica_id: entry.replica_id,
402                        local_timestamp: entry.local_timestamp,
403                        count: entry.count,
404                    })
405                    .collect(),
406            },
407        ),
408        _ => None?,
409    };
410
411    Some(buffer_operation::ActiveModel {
412        buffer_id: ActiveValue::Set(buffer.id),
413        epoch: ActiveValue::Set(buffer.epoch),
414        replica_id: ActiveValue::Set(replica_id as i32),
415        lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
416        value: ActiveValue::Set(value.encode_to_vec()),
417    })
418}
419
420fn operation_from_storage(
421    row: buffer_operation::Model,
422    _format_version: i32,
423) -> Result<proto::operation::Variant, Error> {
424    let operation =
425        storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
426    let version = version_from_storage(&operation.version);
427    Ok(if operation.is_undo {
428        proto::operation::Variant::Undo(proto::operation::Undo {
429            replica_id: row.replica_id as u32,
430            local_timestamp: operation.local_timestamp as u32,
431            lamport_timestamp: row.lamport_timestamp as u32,
432            version,
433            counts: operation
434                .undo_counts
435                .iter()
436                .map(|entry| proto::UndoCount {
437                    replica_id: entry.replica_id,
438                    local_timestamp: entry.local_timestamp,
439                    count: entry.count,
440                })
441                .collect(),
442        })
443    } else {
444        proto::operation::Variant::Edit(proto::operation::Edit {
445            replica_id: row.replica_id as u32,
446            local_timestamp: operation.local_timestamp as u32,
447            lamport_timestamp: row.lamport_timestamp as u32,
448            version,
449            ranges: operation
450                .edit_ranges
451                .into_iter()
452                .map(|range| proto::Range {
453                    start: range.start,
454                    end: range.end,
455                })
456                .collect(),
457            new_text: operation.edit_texts,
458        })
459    })
460}
461
462fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
463    version
464        .iter()
465        .map(|entry| storage::VectorClockEntry {
466            replica_id: entry.replica_id,
467            timestamp: entry.timestamp,
468        })
469        .collect()
470}
471
472fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
473    version
474        .iter()
475        .map(|entry| proto::VectorClockEntry {
476            replica_id: entry.replica_id,
477            timestamp: entry.timestamp,
478        })
479        .collect()
480}
481
482// This is currently a manual copy of the deserialization code in the client's langauge crate
483pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
484    match operation.variant? {
485        proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
486            timestamp: InsertionTimestamp {
487                replica_id: edit.replica_id as text::ReplicaId,
488                local: edit.local_timestamp,
489                lamport: edit.lamport_timestamp,
490            },
491            version: version_from_wire(&edit.version),
492            ranges: edit
493                .ranges
494                .into_iter()
495                .map(|range| {
496                    text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
497                })
498                .collect(),
499            new_text: edit.new_text.into_iter().map(Arc::from).collect(),
500        })),
501        proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo {
502            lamport_timestamp: clock::Lamport {
503                replica_id: undo.replica_id as text::ReplicaId,
504                value: undo.lamport_timestamp,
505            },
506            undo: UndoOperation {
507                id: clock::Local {
508                    replica_id: undo.replica_id as text::ReplicaId,
509                    value: undo.local_timestamp,
510                },
511                version: version_from_wire(&undo.version),
512                counts: undo
513                    .counts
514                    .into_iter()
515                    .map(|c| {
516                        (
517                            clock::Local {
518                                replica_id: c.replica_id as text::ReplicaId,
519                                value: c.local_timestamp,
520                            },
521                            c.count,
522                        )
523                    })
524                    .collect(),
525            },
526        }),
527        _ => None,
528    }
529}
530
531fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
532    let mut version = clock::Global::new();
533    for entry in message {
534        version.observe(clock::Local {
535            replica_id: entry.replica_id as text::ReplicaId,
536            value: entry.timestamp,
537        });
538    }
539    version
540}
541
/// Protobuf messages used to persist buffer operations in the database.
///
/// An encoded `Operation` is what gets stored in `buffer_operation.value`. The version
/// each snapshot was written with is recorded in
/// `buffer_snapshot.operation_serialization_version`.
mod storage {
    // NOTE(review): the reason for this allow is not visible here; presumably it silences
    // lints triggered by prost-generated code. Confirm whether it is still needed.
    #![allow(non_snake_case)]
    use prost::Message;
    /// Serialization version written to `buffer_snapshot.operation_serialization_version`
    /// whenever a snapshot row is created.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// A single edit or undo operation, as encoded into `buffer_operation.value`.
    ///
    /// The operation's replica id and lamport timestamp are not part of this message; they
    /// are stored in their own `buffer_operation` columns.
    #[derive(Message)]
    pub struct Operation {
        /// The operation's local timestamp.
        #[prost(uint32, tag = "1")]
        pub local_timestamp: u32,
        /// The version vector carried by the operation.
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        /// `true` for an undo operation, `false` for an edit.
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        /// Edited ranges; left empty for undo operations.
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        /// The edit's new text (copied from the wire edit's `new_text`); left empty for
        /// undo operations.
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        /// Undo counts; left empty for edit operations.
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// One entry of a version vector: a replica id and its timestamp.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// An edited range, copied to and from the wire `proto::Range`.
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// An undo count entry, copied to and from the wire `proto::UndoCount`.
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub local_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}