// buffers.rs

use super::*;
use prost::Message;
use text::{EditOperation, UndoOperation};

  5impl Database {
  6    pub async fn join_channel_buffer(
  7        &self,
  8        channel_id: ChannelId,
  9        user_id: UserId,
 10        connection: ConnectionId,
 11    ) -> Result<proto::JoinChannelBufferResponse> {
 12        self.transaction(|tx| async move {
 13            let tx = tx;
 14
 15            self.check_user_is_channel_member(channel_id, user_id, &tx)
 16                .await?;
 17
 18            let buffer = channel::Model {
 19                id: channel_id,
 20                ..Default::default()
 21            }
 22            .find_related(buffer::Entity)
 23            .one(&*tx)
 24            .await?;
 25
 26            let buffer = if let Some(buffer) = buffer {
 27                buffer
 28            } else {
 29                let buffer = buffer::ActiveModel {
 30                    channel_id: ActiveValue::Set(channel_id),
 31                    ..Default::default()
 32                }
 33                .insert(&*tx)
 34                .await?;
 35                buffer_snapshot::ActiveModel {
 36                    buffer_id: ActiveValue::Set(buffer.id),
 37                    epoch: ActiveValue::Set(0),
 38                    text: ActiveValue::Set(String::new()),
 39                    operation_serialization_version: ActiveValue::Set(
 40                        storage::SERIALIZATION_VERSION,
 41                    ),
 42                }
 43                .insert(&*tx)
 44                .await?;
 45                buffer
 46            };
 47
 48            // Join the collaborators
 49            let mut collaborators = channel_buffer_collaborator::Entity::find()
 50                .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 51                .all(&*tx)
 52                .await?;
 53            let replica_ids = collaborators
 54                .iter()
 55                .map(|c| c.replica_id)
 56                .collect::<HashSet<_>>();
 57            let mut replica_id = ReplicaId(0);
 58            while replica_ids.contains(&replica_id) {
 59                replica_id.0 += 1;
 60            }
 61            let collaborator = channel_buffer_collaborator::ActiveModel {
 62                channel_id: ActiveValue::Set(channel_id),
 63                connection_id: ActiveValue::Set(connection.id as i32),
 64                connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
 65                user_id: ActiveValue::Set(user_id),
 66                replica_id: ActiveValue::Set(replica_id),
 67                ..Default::default()
 68            }
 69            .insert(&*tx)
 70            .await?;
 71            collaborators.push(collaborator);
 72
 73            // Assemble the buffer state
 74            let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
 75
 76            Ok(proto::JoinChannelBufferResponse {
 77                buffer_id: buffer.id.to_proto(),
 78                replica_id: replica_id.to_proto() as u32,
 79                base_text,
 80                operations,
 81                collaborators: collaborators
 82                    .into_iter()
 83                    .map(|collaborator| proto::Collaborator {
 84                        peer_id: Some(collaborator.connection().into()),
 85                        user_id: collaborator.user_id.to_proto(),
 86                        replica_id: collaborator.replica_id.0 as u32,
 87                    })
 88                    .collect(),
 89            })
 90        })
 91        .await
 92    }
 93
 94    pub async fn leave_channel_buffer(
 95        &self,
 96        channel_id: ChannelId,
 97        connection: ConnectionId,
 98    ) -> Result<Vec<ConnectionId>> {
 99        self.transaction(|tx| async move {
100            self.leave_channel_buffer_internal(channel_id, connection, &*tx)
101                .await
102        })
103        .await
104    }
105
106    pub async fn leave_channel_buffer_internal(
107        &self,
108        channel_id: ChannelId,
109        connection: ConnectionId,
110        tx: &DatabaseTransaction,
111    ) -> Result<Vec<ConnectionId>> {
112        let result = channel_buffer_collaborator::Entity::delete_many()
113            .filter(
114                Condition::all()
115                    .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
116                    .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
117                    .add(
118                        channel_buffer_collaborator::Column::ConnectionServerId
119                            .eq(connection.owner_id as i32),
120                    ),
121            )
122            .exec(&*tx)
123            .await?;
124        if result.rows_affected == 0 {
125            Err(anyhow!("not a collaborator on this project"))?;
126        }
127
128        let mut connections = Vec::new();
129        let mut rows = channel_buffer_collaborator::Entity::find()
130            .filter(
131                Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
132            )
133            .stream(&*tx)
134            .await?;
135        while let Some(row) = rows.next().await {
136            let row = row?;
137            connections.push(ConnectionId {
138                id: row.connection_id as u32,
139                owner_id: row.connection_server_id.0 as u32,
140            });
141        }
142
143        drop(rows);
144
145        if connections.is_empty() {
146            self.snapshot_buffer(channel_id, &tx).await?;
147        }
148
149        Ok(connections)
150    }
151
152    pub async fn leave_channel_buffers(
153        &self,
154        connection: ConnectionId,
155    ) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
156        self.transaction(|tx| async move {
157            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
158            enum QueryChannelIds {
159                ChannelId,
160            }
161
162            let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
163                .select_only()
164                .column(channel_buffer_collaborator::Column::ChannelId)
165                .filter(Condition::all().add(
166                    channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
167                ))
168                .into_values::<_, QueryChannelIds>()
169                .all(&*tx)
170                .await?;
171
172            let mut result = Vec::new();
173            for channel_id in channel_ids {
174                let collaborators = self
175                    .leave_channel_buffer_internal(channel_id, connection, &*tx)
176                    .await?;
177                result.push((channel_id, collaborators));
178            }
179
180            Ok(result)
181        })
182        .await
183    }
184
185    pub async fn get_channel_buffer_collaborators(
186        &self,
187        channel_id: ChannelId,
188    ) -> Result<Vec<UserId>> {
189        self.transaction(|tx| async move {
190            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
191            enum QueryUserIds {
192                UserId,
193            }
194
195            let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
196                .select_only()
197                .column(channel_buffer_collaborator::Column::UserId)
198                .filter(
199                    Condition::all()
200                        .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
201                )
202                .into_values::<_, QueryUserIds>()
203                .all(&*tx)
204                .await?;
205
206            Ok(users)
207        })
208        .await
209    }
210
211    pub async fn update_channel_buffer(
212        &self,
213        channel_id: ChannelId,
214        user: UserId,
215        operations: &[proto::Operation],
216    ) -> Result<Vec<ConnectionId>> {
217        self.transaction(move |tx| async move {
218            self.check_user_is_channel_member(channel_id, user, &*tx)
219                .await?;
220
221            let buffer = buffer::Entity::find()
222                .filter(buffer::Column::ChannelId.eq(channel_id))
223                .one(&*tx)
224                .await?
225                .ok_or_else(|| anyhow!("no such buffer"))?;
226
227            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
228            enum QueryVersion {
229                OperationSerializationVersion,
230            }
231
232            let serialization_version: i32 = buffer
233                .find_related(buffer_snapshot::Entity)
234                .select_only()
235                .column(buffer_snapshot::Column::OperationSerializationVersion)
236                .filter(buffer_snapshot::Column::Epoch.eq(buffer.epoch))
237                .into_values::<_, QueryVersion>()
238                .one(&*tx)
239                .await?
240                .ok_or_else(|| anyhow!("missing buffer snapshot"))?;
241
242            let operations = operations
243                .iter()
244                .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
245                .collect::<Vec<_>>();
246            if !operations.is_empty() {
247                buffer_operation::Entity::insert_many(operations)
248                    .exec(&*tx)
249                    .await?;
250            }
251
252            let mut connections = Vec::new();
253            let mut rows = channel_buffer_collaborator::Entity::find()
254                .filter(
255                    Condition::all()
256                        .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
257                )
258                .stream(&*tx)
259                .await?;
260            while let Some(row) = rows.next().await {
261                let row = row?;
262                connections.push(ConnectionId {
263                    id: row.connection_id as u32,
264                    owner_id: row.connection_server_id.0 as u32,
265                });
266            }
267
268            Ok(connections)
269        })
270        .await
271    }
272
273    async fn get_buffer_state(
274        &self,
275        buffer: &buffer::Model,
276        tx: &DatabaseTransaction,
277    ) -> Result<(String, Vec<proto::Operation>)> {
278        let id = buffer.id;
279        let (base_text, version) = if buffer.epoch > 0 {
280            let snapshot = buffer_snapshot::Entity::find()
281                .filter(
282                    buffer_snapshot::Column::BufferId
283                        .eq(id)
284                        .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
285                )
286                .one(&*tx)
287                .await?
288                .ok_or_else(|| anyhow!("no such snapshot"))?;
289
290            let version = snapshot.operation_serialization_version;
291            (snapshot.text, version)
292        } else {
293            (String::new(), storage::SERIALIZATION_VERSION)
294        };
295
296        let mut rows = buffer_operation::Entity::find()
297            .filter(
298                buffer_operation::Column::BufferId
299                    .eq(id)
300                    .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
301            )
302            .stream(&*tx)
303            .await?;
304        let mut operations = Vec::new();
305        while let Some(row) = rows.next().await {
306            let row = row?;
307
308            let operation = operation_from_storage(row, version)?;
309            operations.push(proto::Operation {
310                variant: Some(operation),
311            })
312        }
313
314        Ok((base_text, operations))
315    }
316
317    async fn snapshot_buffer(&self, channel_id: ChannelId, tx: &DatabaseTransaction) -> Result<()> {
318        let buffer = channel::Model {
319            id: channel_id,
320            ..Default::default()
321        }
322        .find_related(buffer::Entity)
323        .one(&*tx)
324        .await?
325        .ok_or_else(|| anyhow!("no such buffer"))?;
326
327        let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
328        if operations.is_empty() {
329            return Ok(());
330        }
331
332        let mut text_buffer = text::Buffer::new(0, 0, base_text);
333        text_buffer
334            .apply_ops(operations.into_iter().filter_map(operation_from_wire))
335            .unwrap();
336
337        let base_text = text_buffer.text();
338        let epoch = buffer.epoch + 1;
339
340        buffer_snapshot::Model {
341            buffer_id: buffer.id,
342            epoch,
343            text: base_text,
344            operation_serialization_version: storage::SERIALIZATION_VERSION,
345        }
346        .into_active_model()
347        .insert(tx)
348        .await?;
349
350        buffer::ActiveModel {
351            id: ActiveValue::Unchanged(buffer.id),
352            epoch: ActiveValue::Set(epoch),
353            ..Default::default()
354        }
355        .save(tx)
356        .await?;
357
358        Ok(())
359    }
360}
361
362fn operation_to_storage(
363    operation: &proto::Operation,
364    buffer: &buffer::Model,
365    _format: i32,
366) -> Option<buffer_operation::ActiveModel> {
367    let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
368        proto::operation::Variant::Edit(operation) => (
369            operation.replica_id,
370            operation.lamport_timestamp,
371            storage::Operation {
372                version: version_to_storage(&operation.version),
373                is_undo: false,
374                edit_ranges: operation
375                    .ranges
376                    .iter()
377                    .map(|range| storage::Range {
378                        start: range.start,
379                        end: range.end,
380                    })
381                    .collect(),
382                edit_texts: operation.new_text.clone(),
383                undo_counts: Vec::new(),
384            },
385        ),
386        proto::operation::Variant::Undo(operation) => (
387            operation.replica_id,
388            operation.lamport_timestamp,
389            storage::Operation {
390                version: version_to_storage(&operation.version),
391                is_undo: true,
392                edit_ranges: Vec::new(),
393                edit_texts: Vec::new(),
394                undo_counts: operation
395                    .counts
396                    .iter()
397                    .map(|entry| storage::UndoCount {
398                        replica_id: entry.replica_id,
399                        lamport_timestamp: entry.lamport_timestamp,
400                        count: entry.count,
401                    })
402                    .collect(),
403            },
404        ),
405        _ => None?,
406    };
407
408    Some(buffer_operation::ActiveModel {
409        buffer_id: ActiveValue::Set(buffer.id),
410        epoch: ActiveValue::Set(buffer.epoch),
411        replica_id: ActiveValue::Set(replica_id as i32),
412        lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
413        value: ActiveValue::Set(value.encode_to_vec()),
414    })
415}
416
417fn operation_from_storage(
418    row: buffer_operation::Model,
419    _format_version: i32,
420) -> Result<proto::operation::Variant, Error> {
421    let operation =
422        storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
423    let version = version_from_storage(&operation.version);
424    Ok(if operation.is_undo {
425        proto::operation::Variant::Undo(proto::operation::Undo {
426            replica_id: row.replica_id as u32,
427            lamport_timestamp: row.lamport_timestamp as u32,
428            version,
429            counts: operation
430                .undo_counts
431                .iter()
432                .map(|entry| proto::UndoCount {
433                    replica_id: entry.replica_id,
434                    lamport_timestamp: entry.lamport_timestamp,
435                    count: entry.count,
436                })
437                .collect(),
438        })
439    } else {
440        proto::operation::Variant::Edit(proto::operation::Edit {
441            replica_id: row.replica_id as u32,
442            lamport_timestamp: row.lamport_timestamp as u32,
443            version,
444            ranges: operation
445                .edit_ranges
446                .into_iter()
447                .map(|range| proto::Range {
448                    start: range.start,
449                    end: range.end,
450                })
451                .collect(),
452            new_text: operation.edit_texts,
453        })
454    })
455}
456
457fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
458    version
459        .iter()
460        .map(|entry| storage::VectorClockEntry {
461            replica_id: entry.replica_id,
462            timestamp: entry.timestamp,
463        })
464        .collect()
465}
466
467fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
468    version
469        .iter()
470        .map(|entry| proto::VectorClockEntry {
471            replica_id: entry.replica_id,
472            timestamp: entry.timestamp,
473        })
474        .collect()
475}
476
477// This is currently a manual copy of the deserialization code in the client's langauge crate
478pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
479    match operation.variant? {
480        proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
481            timestamp: clock::Lamport {
482                replica_id: edit.replica_id as text::ReplicaId,
483                value: edit.lamport_timestamp,
484            },
485            version: version_from_wire(&edit.version),
486            ranges: edit
487                .ranges
488                .into_iter()
489                .map(|range| {
490                    text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
491                })
492                .collect(),
493            new_text: edit.new_text.into_iter().map(Arc::from).collect(),
494        })),
495        proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
496            timestamp: clock::Lamport {
497                replica_id: undo.replica_id as text::ReplicaId,
498                value: undo.lamport_timestamp,
499            },
500            version: version_from_wire(&undo.version),
501            counts: undo
502                .counts
503                .into_iter()
504                .map(|c| {
505                    (
506                        clock::Lamport {
507                            replica_id: c.replica_id as text::ReplicaId,
508                            value: c.lamport_timestamp,
509                        },
510                        c.count,
511                    )
512                })
513                .collect(),
514        })),
515        _ => None,
516    }
517}
518
519fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
520    let mut version = clock::Global::new();
521    for entry in message {
522        version.observe(clock::Lamport {
523            replica_id: entry.replica_id as text::ReplicaId,
524            value: entry.timestamp,
525        });
526    }
527    version
528}
529
/// Prost message types used to persist buffer operations in the `value` column of
/// `buffer_operation` rows. They are written by `operation_to_storage` and read by
/// `operation_from_storage`.
mod storage {
    #![allow(non_snake_case)]
    // NOTE(review): none of the fields below are non-snake-case. This allow presumably
    // silences lints from the prost-derive expansion — confirm it is still needed.
    use prost::Message;
    /// Serialization format version for stored operations. It is recorded on each snapshot in
    /// its `operation_serialization_version` column.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// A persisted edit or undo operation.
    ///
    /// The replica id and lamport timestamp are stored in their own row columns, not in this
    /// message. Tag 1 is not used here. NOTE(review): presumably reserved — do not reuse it
    /// without checking already-stored data. Changing any tag breaks decoding of existing rows.
    #[derive(Message)]
    pub struct Operation {
        /// The operation's version vector.
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        /// `true` for undo operations, `false` for edits.
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        /// Ranges replaced by an edit; empty for undos.
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        /// The edit's inserted text (the wire `new_text`); empty for undos.
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        /// Undo counts carried by an undo; empty for edits.
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// One entry of a version vector: a replica id and its timestamp.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// Start and end offsets of an edited range.
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// Undo count for the operation identified by (`replica_id`, `lamport_timestamp`).
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub lamport_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}