1use super::*;
2use prost::Message;
3use text::{EditOperation, UndoOperation};
4
5impl Database {
6 pub async fn join_channel_buffer(
7 &self,
8 channel_id: ChannelId,
9 user_id: UserId,
10 connection: ConnectionId,
11 ) -> Result<proto::JoinChannelBufferResponse> {
12 self.transaction(|tx| async move {
13 self.check_user_is_channel_member(channel_id, user_id, &tx)
14 .await?;
15
16 let buffer = channel::Model {
17 id: channel_id,
18 ..Default::default()
19 }
20 .find_related(buffer::Entity)
21 .one(&*tx)
22 .await?;
23
24 let buffer = if let Some(buffer) = buffer {
25 buffer
26 } else {
27 let buffer = buffer::ActiveModel {
28 channel_id: ActiveValue::Set(channel_id),
29 ..Default::default()
30 }
31 .insert(&*tx)
32 .await?;
33 buffer_snapshot::ActiveModel {
34 buffer_id: ActiveValue::Set(buffer.id),
35 epoch: ActiveValue::Set(0),
36 text: ActiveValue::Set(String::new()),
37 operation_serialization_version: ActiveValue::Set(
38 storage::SERIALIZATION_VERSION,
39 ),
40 }
41 .insert(&*tx)
42 .await?;
43 buffer
44 };
45
46 // Join the collaborators
47 let mut collaborators = channel_buffer_collaborator::Entity::find()
48 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
49 .all(&*tx)
50 .await?;
51 let replica_ids = collaborators
52 .iter()
53 .map(|c| c.replica_id)
54 .collect::<HashSet<_>>();
55 let mut replica_id = ReplicaId(0);
56 while replica_ids.contains(&replica_id) {
57 replica_id.0 += 1;
58 }
59 let collaborator = channel_buffer_collaborator::ActiveModel {
60 channel_id: ActiveValue::Set(channel_id),
61 connection_id: ActiveValue::Set(connection.id as i32),
62 connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
63 user_id: ActiveValue::Set(user_id),
64 replica_id: ActiveValue::Set(replica_id),
65 ..Default::default()
66 }
67 .insert(&*tx)
68 .await?;
69 collaborators.push(collaborator);
70
71 let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
72
73 Ok(proto::JoinChannelBufferResponse {
74 buffer_id: buffer.id.to_proto(),
75 replica_id: replica_id.to_proto() as u32,
76 base_text,
77 operations,
78 epoch: buffer.epoch as u64,
79 collaborators: collaborators
80 .into_iter()
81 .map(|collaborator| proto::Collaborator {
82 peer_id: Some(collaborator.connection().into()),
83 user_id: collaborator.user_id.to_proto(),
84 replica_id: collaborator.replica_id.0 as u32,
85 })
86 .collect(),
87 })
88 })
89 .await
90 }
91
92 pub async fn rejoin_channel_buffers(
93 &self,
94 buffers: &[proto::ChannelBufferVersion],
95 user_id: UserId,
96 connection_id: ConnectionId,
97 ) -> Result<Vec<RejoinedChannelBuffer>> {
98 self.transaction(|tx| async move {
99 let mut results = Vec::new();
100 for client_buffer in buffers {
101 let channel_id = ChannelId::from_proto(client_buffer.channel_id);
102 if self
103 .check_user_is_channel_member(channel_id, user_id, &*tx)
104 .await
105 .is_err()
106 {
107 log::info!("user is not a member of channel");
108 continue;
109 }
110
111 let buffer = self.get_channel_buffer(channel_id, &*tx).await?;
112 let mut collaborators = channel_buffer_collaborator::Entity::find()
113 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
114 .all(&*tx)
115 .await?;
116
117 // If the buffer epoch hasn't changed since the client lost
118 // connection, then the client's buffer can be syncronized with
119 // the server's buffer.
120 if buffer.epoch as u64 != client_buffer.epoch {
121 log::info!("can't rejoin buffer, epoch has changed");
122 continue;
123 }
124
125 // Find the collaborator record for this user's previous lost
126 // connection. Update it with the new connection id.
127 let server_id = ServerId(connection_id.owner_id as i32);
128 let Some(self_collaborator) = collaborators.iter_mut().find(|c| {
129 c.user_id == user_id
130 && (c.connection_lost || c.connection_server_id != server_id)
131 }) else {
132 log::info!("can't rejoin buffer, no previous collaborator found");
133 continue;
134 };
135 let old_connection_id = self_collaborator.connection();
136 *self_collaborator = channel_buffer_collaborator::ActiveModel {
137 id: ActiveValue::Unchanged(self_collaborator.id),
138 connection_id: ActiveValue::Set(connection_id.id as i32),
139 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
140 connection_lost: ActiveValue::Set(false),
141 ..Default::default()
142 }
143 .update(&*tx)
144 .await?;
145
146 let client_version = version_from_wire(&client_buffer.version);
147 let serialization_version = self
148 .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
149 .await?;
150
151 let mut rows = buffer_operation::Entity::find()
152 .filter(
153 buffer_operation::Column::BufferId
154 .eq(buffer.id)
155 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
156 )
157 .stream(&*tx)
158 .await?;
159
160 // Find the server's version vector and any operations
161 // that the client has not seen.
162 let mut server_version = clock::Global::new();
163 let mut operations = Vec::new();
164 while let Some(row) = rows.next().await {
165 let row = row?;
166 let timestamp = clock::Lamport {
167 replica_id: row.replica_id as u16,
168 value: row.lamport_timestamp as u32,
169 };
170 server_version.observe(timestamp);
171 if !client_version.observed(timestamp) {
172 operations.push(proto::Operation {
173 variant: Some(operation_from_storage(row, serialization_version)?),
174 })
175 }
176 }
177
178 results.push(RejoinedChannelBuffer {
179 old_connection_id,
180 buffer: proto::RejoinedChannelBuffer {
181 channel_id: client_buffer.channel_id,
182 version: version_to_wire(&server_version),
183 operations,
184 collaborators: collaborators
185 .into_iter()
186 .map(|collaborator| proto::Collaborator {
187 peer_id: Some(collaborator.connection().into()),
188 user_id: collaborator.user_id.to_proto(),
189 replica_id: collaborator.replica_id.0 as u32,
190 })
191 .collect(),
192 },
193 });
194 }
195
196 Ok(results)
197 })
198 .await
199 }
200
201 pub async fn clear_stale_channel_buffer_collaborators(
202 &self,
203 channel_id: ChannelId,
204 server_id: ServerId,
205 ) -> Result<RefreshedChannelBuffer> {
206 self.transaction(|tx| async move {
207 let collaborators = channel_buffer_collaborator::Entity::find()
208 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
209 .all(&*tx)
210 .await?;
211
212 let mut connection_ids = Vec::new();
213 let mut removed_collaborators = Vec::new();
214 let mut collaborator_ids_to_remove = Vec::new();
215 for collaborator in &collaborators {
216 if !collaborator.connection_lost && collaborator.connection_server_id == server_id {
217 connection_ids.push(collaborator.connection());
218 } else {
219 removed_collaborators.push(proto::RemoveChannelBufferCollaborator {
220 channel_id: channel_id.to_proto(),
221 peer_id: Some(collaborator.connection().into()),
222 });
223 collaborator_ids_to_remove.push(collaborator.id);
224 }
225 }
226
227 channel_buffer_collaborator::Entity::delete_many()
228 .filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove))
229 .exec(&*tx)
230 .await?;
231
232 Ok(RefreshedChannelBuffer {
233 connection_ids,
234 removed_collaborators,
235 })
236 })
237 .await
238 }
239
240 pub async fn leave_channel_buffer(
241 &self,
242 channel_id: ChannelId,
243 connection: ConnectionId,
244 ) -> Result<Vec<ConnectionId>> {
245 self.transaction(|tx| async move {
246 self.leave_channel_buffer_internal(channel_id, connection, &*tx)
247 .await
248 })
249 .await
250 }
251
252 pub async fn leave_channel_buffers(
253 &self,
254 connection: ConnectionId,
255 ) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
256 self.transaction(|tx| async move {
257 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
258 enum QueryChannelIds {
259 ChannelId,
260 }
261
262 let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
263 .select_only()
264 .column(channel_buffer_collaborator::Column::ChannelId)
265 .filter(Condition::all().add(
266 channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
267 ))
268 .into_values::<_, QueryChannelIds>()
269 .all(&*tx)
270 .await?;
271
272 let mut result = Vec::new();
273 for channel_id in channel_ids {
274 let collaborators = self
275 .leave_channel_buffer_internal(channel_id, connection, &*tx)
276 .await?;
277 result.push((channel_id, collaborators));
278 }
279
280 Ok(result)
281 })
282 .await
283 }
284
285 pub async fn leave_channel_buffer_internal(
286 &self,
287 channel_id: ChannelId,
288 connection: ConnectionId,
289 tx: &DatabaseTransaction,
290 ) -> Result<Vec<ConnectionId>> {
291 let result = channel_buffer_collaborator::Entity::delete_many()
292 .filter(
293 Condition::all()
294 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
295 .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
296 .add(
297 channel_buffer_collaborator::Column::ConnectionServerId
298 .eq(connection.owner_id as i32),
299 ),
300 )
301 .exec(&*tx)
302 .await?;
303 if result.rows_affected == 0 {
304 Err(anyhow!("not a collaborator on this project"))?;
305 }
306
307 let mut connections = Vec::new();
308 let mut rows = channel_buffer_collaborator::Entity::find()
309 .filter(
310 Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
311 )
312 .stream(&*tx)
313 .await?;
314 while let Some(row) = rows.next().await {
315 let row = row?;
316 connections.push(ConnectionId {
317 id: row.connection_id as u32,
318 owner_id: row.connection_server_id.0 as u32,
319 });
320 }
321
322 drop(rows);
323
324 if connections.is_empty() {
325 self.snapshot_channel_buffer(channel_id, &tx).await?;
326 }
327
328 Ok(connections)
329 }
330
331 pub async fn get_channel_buffer_collaborators(
332 &self,
333 channel_id: ChannelId,
334 ) -> Result<Vec<UserId>> {
335 self.transaction(|tx| async move {
336 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
337 enum QueryUserIds {
338 UserId,
339 }
340
341 let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
342 .select_only()
343 .column(channel_buffer_collaborator::Column::UserId)
344 .filter(
345 Condition::all()
346 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
347 )
348 .into_values::<_, QueryUserIds>()
349 .all(&*tx)
350 .await?;
351
352 Ok(users)
353 })
354 .await
355 }
356
357 pub async fn update_channel_buffer(
358 &self,
359 channel_id: ChannelId,
360 user: UserId,
361 operations: &[proto::Operation],
362 ) -> Result<Vec<ConnectionId>> {
363 self.transaction(move |tx| async move {
364 self.check_user_is_channel_member(channel_id, user, &*tx)
365 .await?;
366
367 let buffer = buffer::Entity::find()
368 .filter(buffer::Column::ChannelId.eq(channel_id))
369 .one(&*tx)
370 .await?
371 .ok_or_else(|| anyhow!("no such buffer"))?;
372
373 let serialization_version = self
374 .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
375 .await?;
376
377 let operations = operations
378 .iter()
379 .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
380 .collect::<Vec<_>>();
381 if !operations.is_empty() {
382 buffer_operation::Entity::insert_many(operations)
383 .on_conflict(
384 OnConflict::columns([
385 buffer_operation::Column::BufferId,
386 buffer_operation::Column::Epoch,
387 buffer_operation::Column::LamportTimestamp,
388 buffer_operation::Column::ReplicaId,
389 ])
390 .do_nothing()
391 .to_owned(),
392 )
393 .exec(&*tx)
394 .await?;
395 }
396
397 let mut connections = Vec::new();
398 let mut rows = channel_buffer_collaborator::Entity::find()
399 .filter(
400 Condition::all()
401 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
402 )
403 .stream(&*tx)
404 .await?;
405 while let Some(row) = rows.next().await {
406 let row = row?;
407 connections.push(ConnectionId {
408 id: row.connection_id as u32,
409 owner_id: row.connection_server_id.0 as u32,
410 });
411 }
412
413 Ok(connections)
414 })
415 .await
416 }
417
418 async fn get_buffer_operation_serialization_version(
419 &self,
420 buffer_id: BufferId,
421 epoch: i32,
422 tx: &DatabaseTransaction,
423 ) -> Result<i32> {
424 Ok(buffer_snapshot::Entity::find()
425 .filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
426 .filter(buffer_snapshot::Column::Epoch.eq(epoch))
427 .select_only()
428 .column(buffer_snapshot::Column::OperationSerializationVersion)
429 .into_values::<_, QueryOperationSerializationVersion>()
430 .one(&*tx)
431 .await?
432 .ok_or_else(|| anyhow!("missing buffer snapshot"))?)
433 }
434
435 async fn get_channel_buffer(
436 &self,
437 channel_id: ChannelId,
438 tx: &DatabaseTransaction,
439 ) -> Result<buffer::Model> {
440 Ok(channel::Model {
441 id: channel_id,
442 ..Default::default()
443 }
444 .find_related(buffer::Entity)
445 .one(&*tx)
446 .await?
447 .ok_or_else(|| anyhow!("no such buffer"))?)
448 }
449
450 async fn get_buffer_state(
451 &self,
452 buffer: &buffer::Model,
453 tx: &DatabaseTransaction,
454 ) -> Result<(String, Vec<proto::Operation>)> {
455 let id = buffer.id;
456 let (base_text, version) = if buffer.epoch > 0 {
457 let snapshot = buffer_snapshot::Entity::find()
458 .filter(
459 buffer_snapshot::Column::BufferId
460 .eq(id)
461 .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
462 )
463 .one(&*tx)
464 .await?
465 .ok_or_else(|| anyhow!("no such snapshot"))?;
466
467 let version = snapshot.operation_serialization_version;
468 (snapshot.text, version)
469 } else {
470 (String::new(), storage::SERIALIZATION_VERSION)
471 };
472
473 let mut rows = buffer_operation::Entity::find()
474 .filter(
475 buffer_operation::Column::BufferId
476 .eq(id)
477 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
478 )
479 .stream(&*tx)
480 .await?;
481 let mut operations = Vec::new();
482 while let Some(row) = rows.next().await {
483 operations.push(proto::Operation {
484 variant: Some(operation_from_storage(row?, version)?),
485 })
486 }
487
488 Ok((base_text, operations))
489 }
490
491 async fn snapshot_channel_buffer(
492 &self,
493 channel_id: ChannelId,
494 tx: &DatabaseTransaction,
495 ) -> Result<()> {
496 let buffer = self.get_channel_buffer(channel_id, tx).await?;
497 let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
498 if operations.is_empty() {
499 return Ok(());
500 }
501
502 let mut text_buffer = text::Buffer::new(0, 0, base_text);
503 text_buffer
504 .apply_ops(operations.into_iter().filter_map(operation_from_wire))
505 .unwrap();
506
507 let base_text = text_buffer.text();
508 let epoch = buffer.epoch + 1;
509
510 buffer_snapshot::Model {
511 buffer_id: buffer.id,
512 epoch,
513 text: base_text,
514 operation_serialization_version: storage::SERIALIZATION_VERSION,
515 }
516 .into_active_model()
517 .insert(tx)
518 .await?;
519
520 buffer::ActiveModel {
521 id: ActiveValue::Unchanged(buffer.id),
522 epoch: ActiveValue::Set(epoch),
523 ..Default::default()
524 }
525 .save(tx)
526 .await?;
527
528 Ok(())
529 }
530}
531
532fn operation_to_storage(
533 operation: &proto::Operation,
534 buffer: &buffer::Model,
535 _format: i32,
536) -> Option<buffer_operation::ActiveModel> {
537 let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
538 proto::operation::Variant::Edit(operation) => (
539 operation.replica_id,
540 operation.lamport_timestamp,
541 storage::Operation {
542 version: version_to_storage(&operation.version),
543 is_undo: false,
544 edit_ranges: operation
545 .ranges
546 .iter()
547 .map(|range| storage::Range {
548 start: range.start,
549 end: range.end,
550 })
551 .collect(),
552 edit_texts: operation.new_text.clone(),
553 undo_counts: Vec::new(),
554 },
555 ),
556 proto::operation::Variant::Undo(operation) => (
557 operation.replica_id,
558 operation.lamport_timestamp,
559 storage::Operation {
560 version: version_to_storage(&operation.version),
561 is_undo: true,
562 edit_ranges: Vec::new(),
563 edit_texts: Vec::new(),
564 undo_counts: operation
565 .counts
566 .iter()
567 .map(|entry| storage::UndoCount {
568 replica_id: entry.replica_id,
569 lamport_timestamp: entry.lamport_timestamp,
570 count: entry.count,
571 })
572 .collect(),
573 },
574 ),
575 _ => None?,
576 };
577
578 Some(buffer_operation::ActiveModel {
579 buffer_id: ActiveValue::Set(buffer.id),
580 epoch: ActiveValue::Set(buffer.epoch),
581 replica_id: ActiveValue::Set(replica_id as i32),
582 lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
583 value: ActiveValue::Set(value.encode_to_vec()),
584 })
585}
586
587fn operation_from_storage(
588 row: buffer_operation::Model,
589 _format_version: i32,
590) -> Result<proto::operation::Variant, Error> {
591 let operation =
592 storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
593 let version = version_from_storage(&operation.version);
594 Ok(if operation.is_undo {
595 proto::operation::Variant::Undo(proto::operation::Undo {
596 replica_id: row.replica_id as u32,
597 lamport_timestamp: row.lamport_timestamp as u32,
598 version,
599 counts: operation
600 .undo_counts
601 .iter()
602 .map(|entry| proto::UndoCount {
603 replica_id: entry.replica_id,
604 lamport_timestamp: entry.lamport_timestamp,
605 count: entry.count,
606 })
607 .collect(),
608 })
609 } else {
610 proto::operation::Variant::Edit(proto::operation::Edit {
611 replica_id: row.replica_id as u32,
612 lamport_timestamp: row.lamport_timestamp as u32,
613 version,
614 ranges: operation
615 .edit_ranges
616 .into_iter()
617 .map(|range| proto::Range {
618 start: range.start,
619 end: range.end,
620 })
621 .collect(),
622 new_text: operation.edit_texts,
623 })
624 })
625}
626
627fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
628 version
629 .iter()
630 .map(|entry| storage::VectorClockEntry {
631 replica_id: entry.replica_id,
632 timestamp: entry.timestamp,
633 })
634 .collect()
635}
636
637fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
638 version
639 .iter()
640 .map(|entry| proto::VectorClockEntry {
641 replica_id: entry.replica_id,
642 timestamp: entry.timestamp,
643 })
644 .collect()
645}
646
647// This is currently a manual copy of the deserialization code in the client's langauge crate
648pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
649 match operation.variant? {
650 proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
651 timestamp: clock::Lamport {
652 replica_id: edit.replica_id as text::ReplicaId,
653 value: edit.lamport_timestamp,
654 },
655 version: version_from_wire(&edit.version),
656 ranges: edit
657 .ranges
658 .into_iter()
659 .map(|range| {
660 text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
661 })
662 .collect(),
663 new_text: edit.new_text.into_iter().map(Arc::from).collect(),
664 })),
665 proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
666 timestamp: clock::Lamport {
667 replica_id: undo.replica_id as text::ReplicaId,
668 value: undo.lamport_timestamp,
669 },
670 version: version_from_wire(&undo.version),
671 counts: undo
672 .counts
673 .into_iter()
674 .map(|c| {
675 (
676 clock::Lamport {
677 replica_id: c.replica_id as text::ReplicaId,
678 value: c.lamport_timestamp,
679 },
680 c.count,
681 )
682 })
683 .collect(),
684 })),
685 _ => None,
686 }
687}
688
689fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
690 let mut version = clock::Global::new();
691 for entry in message {
692 version.observe(clock::Lamport {
693 replica_id: entry.replica_id as text::ReplicaId,
694 value: entry.timestamp,
695 });
696 }
697 version
698}
699
700fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
701 let mut message = Vec::new();
702 for entry in version.iter() {
703 message.push(proto::VectorClockEntry {
704 replica_id: entry.replica_id as u32,
705 timestamp: entry.value,
706 });
707 }
708 message
709}
710
/// Column selector used with `into_values` in
/// `get_buffer_operation_serialization_version`, which selects only
/// `buffer_snapshot.operation_serialization_version`.
// NOTE(review): the variant name is presumably matched to the selected column
// by `DeriveColumn`; do not rename it independently of the column.
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryOperationSerializationVersion {
    OperationSerializationVersion,
}
715
/// Protobuf messages used to persist buffer operations in the database.
///
/// `Operation` is encoded into, and decoded from, the `value` column of
/// `buffer_operation` rows. These types are separate from the wire-protocol
/// `proto` types. Field tags define the stored format, so changing them would
/// break decoding of rows that are already persisted.
mod storage {
    #![allow(non_snake_case)]
    use prost::Message;
    /// Serialization version written to `operation_serialization_version` on
    /// newly created buffer snapshots.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// A persisted edit or undo operation.
    ///
    /// The replica id and Lamport timestamp are stored as separate columns of
    /// the row, not in this message. Tag 1 is not used.
    #[derive(Message)]
    pub struct Operation {
        /// Version vector attached to the operation.
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        /// `true` for an undo operation, `false` for an edit.
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        /// Edited ranges; empty for undo operations.
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        /// The edit's new text strings; empty for undo operations.
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        /// Undo counts; empty for edit operations.
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// One entry of a stored vector clock.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// Start and end offsets of an edited range.
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// An operation's Lamport timestamp (replica id and value) paired with
    /// its undo count.
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub lamport_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}