1use super::*;
2use prost::Message;
3use text::{EditOperation, UndoOperation};
4
5pub struct LeftChannelBuffer {
6 pub channel_id: ChannelId,
7 pub collaborators: Vec<proto::Collaborator>,
8 pub connections: Vec<ConnectionId>,
9}
10
11impl Database {
12 /// Open a channel buffer. Returns the current contents, and adds you to the list of people
13 /// to notify on changes.
14 pub async fn join_channel_buffer(
15 &self,
16 channel_id: ChannelId,
17 user_id: UserId,
18 connection: ConnectionId,
19 ) -> Result<proto::JoinChannelBufferResponse> {
20 self.transaction(|tx| async move {
21 let channel = self.get_channel_internal(channel_id, &*tx).await?;
22 self.check_user_is_channel_participant(&channel, user_id, &tx)
23 .await?;
24
25 let buffer = channel::Model {
26 id: channel_id,
27 ..Default::default()
28 }
29 .find_related(buffer::Entity)
30 .one(&*tx)
31 .await?;
32
33 let buffer = if let Some(buffer) = buffer {
34 buffer
35 } else {
36 let buffer = buffer::ActiveModel {
37 channel_id: ActiveValue::Set(channel_id),
38 ..Default::default()
39 }
40 .insert(&*tx)
41 .await?;
42 buffer_snapshot::ActiveModel {
43 buffer_id: ActiveValue::Set(buffer.id),
44 epoch: ActiveValue::Set(0),
45 text: ActiveValue::Set(String::new()),
46 operation_serialization_version: ActiveValue::Set(
47 storage::SERIALIZATION_VERSION,
48 ),
49 }
50 .insert(&*tx)
51 .await?;
52 buffer
53 };
54
55 // Join the collaborators
56 let mut collaborators = channel_buffer_collaborator::Entity::find()
57 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
58 .all(&*tx)
59 .await?;
60 let replica_ids = collaborators
61 .iter()
62 .map(|c| c.replica_id)
63 .collect::<HashSet<_>>();
64 let mut replica_id = ReplicaId(0);
65 while replica_ids.contains(&replica_id) {
66 replica_id.0 += 1;
67 }
68 let collaborator = channel_buffer_collaborator::ActiveModel {
69 channel_id: ActiveValue::Set(channel_id),
70 connection_id: ActiveValue::Set(connection.id as i32),
71 connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
72 user_id: ActiveValue::Set(user_id),
73 replica_id: ActiveValue::Set(replica_id),
74 ..Default::default()
75 }
76 .insert(&*tx)
77 .await?;
78 collaborators.push(collaborator);
79
80 let (base_text, operations, max_operation) =
81 self.get_buffer_state(&buffer, &tx).await?;
82
83 // Save the last observed operation
84 if let Some(op) = max_operation {
85 observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
86 user_id: ActiveValue::Set(user_id),
87 buffer_id: ActiveValue::Set(buffer.id),
88 epoch: ActiveValue::Set(op.epoch),
89 lamport_timestamp: ActiveValue::Set(op.lamport_timestamp),
90 replica_id: ActiveValue::Set(op.replica_id),
91 })
92 .on_conflict(
93 OnConflict::columns([
94 observed_buffer_edits::Column::UserId,
95 observed_buffer_edits::Column::BufferId,
96 ])
97 .update_columns([
98 observed_buffer_edits::Column::Epoch,
99 observed_buffer_edits::Column::LamportTimestamp,
100 ])
101 .to_owned(),
102 )
103 .exec(&*tx)
104 .await?;
105 }
106
107 Ok(proto::JoinChannelBufferResponse {
108 buffer_id: buffer.id.to_proto(),
109 replica_id: replica_id.to_proto() as u32,
110 base_text,
111 operations,
112 epoch: buffer.epoch as u64,
113 collaborators: collaborators
114 .into_iter()
115 .map(|collaborator| proto::Collaborator {
116 peer_id: Some(collaborator.connection().into()),
117 user_id: collaborator.user_id.to_proto(),
118 replica_id: collaborator.replica_id.0 as u32,
119 })
120 .collect(),
121 })
122 })
123 .await
124 }
125
126 /// Rejoin a channel buffer (after a connection interruption)
127 pub async fn rejoin_channel_buffers(
128 &self,
129 buffers: &[proto::ChannelBufferVersion],
130 user_id: UserId,
131 connection_id: ConnectionId,
132 ) -> Result<Vec<RejoinedChannelBuffer>> {
133 self.transaction(|tx| async move {
134 let mut results = Vec::new();
135 for client_buffer in buffers {
136 let channel = self
137 .get_channel_internal(ChannelId::from_proto(client_buffer.channel_id), &*tx)
138 .await?;
139 if self
140 .check_user_is_channel_participant(&channel, user_id, &*tx)
141 .await
142 .is_err()
143 {
144 log::info!("user is not a member of channel");
145 continue;
146 }
147
148 let buffer = self.get_channel_buffer(channel.id, &*tx).await?;
149 let mut collaborators = channel_buffer_collaborator::Entity::find()
150 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel.id))
151 .all(&*tx)
152 .await?;
153
154 // If the buffer epoch hasn't changed since the client lost
155 // connection, then the client's buffer can be synchronized with
156 // the server's buffer.
157 if buffer.epoch as u64 != client_buffer.epoch {
158 log::info!("can't rejoin buffer, epoch has changed");
159 continue;
160 }
161
162 // Find the collaborator record for this user's previous lost
163 // connection. Update it with the new connection id.
164 let server_id = ServerId(connection_id.owner_id as i32);
165 let Some(self_collaborator) = collaborators.iter_mut().find(|c| {
166 c.user_id == user_id
167 && (c.connection_lost || c.connection_server_id != server_id)
168 }) else {
169 log::info!("can't rejoin buffer, no previous collaborator found");
170 continue;
171 };
172 let old_connection_id = self_collaborator.connection();
173 *self_collaborator = channel_buffer_collaborator::ActiveModel {
174 id: ActiveValue::Unchanged(self_collaborator.id),
175 connection_id: ActiveValue::Set(connection_id.id as i32),
176 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
177 connection_lost: ActiveValue::Set(false),
178 ..Default::default()
179 }
180 .update(&*tx)
181 .await?;
182
183 let client_version = version_from_wire(&client_buffer.version);
184 let serialization_version = self
185 .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
186 .await?;
187
188 let mut rows = buffer_operation::Entity::find()
189 .filter(
190 buffer_operation::Column::BufferId
191 .eq(buffer.id)
192 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
193 )
194 .stream(&*tx)
195 .await?;
196
197 // Find the server's version vector and any operations
198 // that the client has not seen.
199 let mut server_version = clock::Global::new();
200 let mut operations = Vec::new();
201 while let Some(row) = rows.next().await {
202 let row = row?;
203 let timestamp = clock::Lamport {
204 replica_id: row.replica_id as u16,
205 value: row.lamport_timestamp as u32,
206 };
207 server_version.observe(timestamp);
208 if !client_version.observed(timestamp) {
209 operations.push(proto::Operation {
210 variant: Some(operation_from_storage(row, serialization_version)?),
211 })
212 }
213 }
214
215 results.push(RejoinedChannelBuffer {
216 old_connection_id,
217 buffer: proto::RejoinedChannelBuffer {
218 channel_id: client_buffer.channel_id,
219 version: version_to_wire(&server_version),
220 operations,
221 collaborators: collaborators
222 .into_iter()
223 .map(|collaborator| proto::Collaborator {
224 peer_id: Some(collaborator.connection().into()),
225 user_id: collaborator.user_id.to_proto(),
226 replica_id: collaborator.replica_id.0 as u32,
227 })
228 .collect(),
229 },
230 });
231 }
232
233 Ok(results)
234 })
235 .await
236 }
237
238 /// Clear out any buffer collaborators who are no longer collaborating.
239 pub async fn clear_stale_channel_buffer_collaborators(
240 &self,
241 channel_id: ChannelId,
242 server_id: ServerId,
243 ) -> Result<RefreshedChannelBuffer> {
244 self.transaction(|tx| async move {
245 let db_collaborators = channel_buffer_collaborator::Entity::find()
246 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
247 .all(&*tx)
248 .await?;
249
250 let mut connection_ids = Vec::new();
251 let mut collaborators = Vec::new();
252 let mut collaborator_ids_to_remove = Vec::new();
253 for db_collaborator in &db_collaborators {
254 if !db_collaborator.connection_lost
255 && db_collaborator.connection_server_id == server_id
256 {
257 connection_ids.push(db_collaborator.connection());
258 collaborators.push(proto::Collaborator {
259 peer_id: Some(db_collaborator.connection().into()),
260 replica_id: db_collaborator.replica_id.0 as u32,
261 user_id: db_collaborator.user_id.to_proto(),
262 })
263 } else {
264 collaborator_ids_to_remove.push(db_collaborator.id);
265 }
266 }
267
268 channel_buffer_collaborator::Entity::delete_many()
269 .filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove))
270 .exec(&*tx)
271 .await?;
272
273 Ok(RefreshedChannelBuffer {
274 connection_ids,
275 collaborators,
276 })
277 })
278 .await
279 }
280
281 /// Close the channel buffer, and stop receiving updates for it.
282 pub async fn leave_channel_buffer(
283 &self,
284 channel_id: ChannelId,
285 connection: ConnectionId,
286 ) -> Result<LeftChannelBuffer> {
287 self.transaction(|tx| async move {
288 self.leave_channel_buffer_internal(channel_id, connection, &*tx)
289 .await
290 })
291 .await
292 }
293
294 /// Close the channel buffer, and stop receiving updates for it.
295 pub async fn channel_buffer_connection_lost(
296 &self,
297 connection: ConnectionId,
298 tx: &DatabaseTransaction,
299 ) -> Result<()> {
300 channel_buffer_collaborator::Entity::update_many()
301 .filter(
302 Condition::all()
303 .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
304 .add(
305 channel_buffer_collaborator::Column::ConnectionServerId
306 .eq(connection.owner_id as i32),
307 ),
308 )
309 .set(channel_buffer_collaborator::ActiveModel {
310 connection_lost: ActiveValue::set(true),
311 ..Default::default()
312 })
313 .exec(&*tx)
314 .await?;
315 Ok(())
316 }
317
318 /// Close all open channel buffers
319 pub async fn leave_channel_buffers(
320 &self,
321 connection: ConnectionId,
322 ) -> Result<Vec<LeftChannelBuffer>> {
323 self.transaction(|tx| async move {
324 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
325 enum QueryChannelIds {
326 ChannelId,
327 }
328
329 let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
330 .select_only()
331 .column(channel_buffer_collaborator::Column::ChannelId)
332 .filter(Condition::all().add(
333 channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
334 ))
335 .into_values::<_, QueryChannelIds>()
336 .all(&*tx)
337 .await?;
338
339 let mut result = Vec::new();
340 for channel_id in channel_ids {
341 let left_channel_buffer = self
342 .leave_channel_buffer_internal(channel_id, connection, &*tx)
343 .await?;
344 result.push(left_channel_buffer);
345 }
346
347 Ok(result)
348 })
349 .await
350 }
351
352 async fn leave_channel_buffer_internal(
353 &self,
354 channel_id: ChannelId,
355 connection: ConnectionId,
356 tx: &DatabaseTransaction,
357 ) -> Result<LeftChannelBuffer> {
358 let result = channel_buffer_collaborator::Entity::delete_many()
359 .filter(
360 Condition::all()
361 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
362 .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
363 .add(
364 channel_buffer_collaborator::Column::ConnectionServerId
365 .eq(connection.owner_id as i32),
366 ),
367 )
368 .exec(&*tx)
369 .await?;
370 if result.rows_affected == 0 {
371 Err(anyhow!("not a collaborator on this project"))?;
372 }
373
374 let mut collaborators = Vec::new();
375 let mut connections = Vec::new();
376 let mut rows = channel_buffer_collaborator::Entity::find()
377 .filter(
378 Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
379 )
380 .stream(&*tx)
381 .await?;
382 while let Some(row) = rows.next().await {
383 let row = row?;
384 let connection = row.connection();
385 connections.push(connection);
386 collaborators.push(proto::Collaborator {
387 peer_id: Some(connection.into()),
388 replica_id: row.replica_id.0 as u32,
389 user_id: row.user_id.to_proto(),
390 });
391 }
392
393 drop(rows);
394
395 if collaborators.is_empty() {
396 self.snapshot_channel_buffer(channel_id, &tx).await?;
397 }
398
399 Ok(LeftChannelBuffer {
400 channel_id,
401 collaborators,
402 connections,
403 })
404 }
405
406 pub async fn get_channel_buffer_collaborators(
407 &self,
408 channel_id: ChannelId,
409 ) -> Result<Vec<UserId>> {
410 self.transaction(|tx| async move {
411 self.get_channel_buffer_collaborators_internal(channel_id, &*tx)
412 .await
413 })
414 .await
415 }
416
417 async fn get_channel_buffer_collaborators_internal(
418 &self,
419 channel_id: ChannelId,
420 tx: &DatabaseTransaction,
421 ) -> Result<Vec<UserId>> {
422 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
423 enum QueryUserIds {
424 UserId,
425 }
426
427 let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
428 .select_only()
429 .column(channel_buffer_collaborator::Column::UserId)
430 .filter(
431 Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
432 )
433 .into_values::<_, QueryUserIds>()
434 .all(&*tx)
435 .await?;
436
437 Ok(users)
438 }
439
440 pub async fn update_channel_buffer(
441 &self,
442 channel_id: ChannelId,
443 user: UserId,
444 operations: &[proto::Operation],
445 ) -> Result<(
446 Vec<ConnectionId>,
447 Vec<UserId>,
448 i32,
449 Vec<proto::VectorClockEntry>,
450 )> {
451 self.transaction(move |tx| async move {
452 let channel = self.get_channel_internal(channel_id, &*tx).await?;
453
454 let mut requires_write_permission = false;
455 for op in operations.iter() {
456 match op.variant {
457 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
458 Some(_) => requires_write_permission = true,
459 }
460 }
461 if requires_write_permission {
462 self.check_user_is_channel_member(&channel, user, &*tx)
463 .await?;
464 } else {
465 self.check_user_is_channel_participant(&channel, user, &*tx)
466 .await?;
467 }
468
469 let buffer = buffer::Entity::find()
470 .filter(buffer::Column::ChannelId.eq(channel_id))
471 .one(&*tx)
472 .await?
473 .ok_or_else(|| anyhow!("no such buffer"))?;
474
475 let serialization_version = self
476 .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
477 .await?;
478
479 let operations = operations
480 .iter()
481 .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
482 .collect::<Vec<_>>();
483
484 let mut channel_members;
485 let max_version;
486
487 if !operations.is_empty() {
488 let max_operation = operations
489 .iter()
490 .max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
491 .unwrap();
492
493 max_version = vec![proto::VectorClockEntry {
494 replica_id: *max_operation.replica_id.as_ref() as u32,
495 timestamp: *max_operation.lamport_timestamp.as_ref() as u32,
496 }];
497
498 // get current channel participants and save the max operation above
499 self.save_max_operation(
500 user,
501 buffer.id,
502 buffer.epoch,
503 *max_operation.replica_id.as_ref(),
504 *max_operation.lamport_timestamp.as_ref(),
505 &*tx,
506 )
507 .await?;
508
509 channel_members = self.get_channel_participants(&channel, &*tx).await?;
510 let collaborators = self
511 .get_channel_buffer_collaborators_internal(channel_id, &*tx)
512 .await?;
513 channel_members.retain(|member| !collaborators.contains(member));
514
515 buffer_operation::Entity::insert_many(operations)
516 .on_conflict(
517 OnConflict::columns([
518 buffer_operation::Column::BufferId,
519 buffer_operation::Column::Epoch,
520 buffer_operation::Column::LamportTimestamp,
521 buffer_operation::Column::ReplicaId,
522 ])
523 .do_nothing()
524 .to_owned(),
525 )
526 .exec(&*tx)
527 .await?;
528 } else {
529 channel_members = Vec::new();
530 max_version = Vec::new();
531 }
532
533 let mut connections = Vec::new();
534 let mut rows = channel_buffer_collaborator::Entity::find()
535 .filter(
536 Condition::all()
537 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
538 )
539 .stream(&*tx)
540 .await?;
541 while let Some(row) = rows.next().await {
542 let row = row?;
543 connections.push(ConnectionId {
544 id: row.connection_id as u32,
545 owner_id: row.connection_server_id.0 as u32,
546 });
547 }
548
549 Ok((connections, channel_members, buffer.epoch, max_version))
550 })
551 .await
552 }
553
554 async fn save_max_operation(
555 &self,
556 user_id: UserId,
557 buffer_id: BufferId,
558 epoch: i32,
559 replica_id: i32,
560 lamport_timestamp: i32,
561 tx: &DatabaseTransaction,
562 ) -> Result<()> {
563 use observed_buffer_edits::Column;
564
565 observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
566 user_id: ActiveValue::Set(user_id),
567 buffer_id: ActiveValue::Set(buffer_id),
568 epoch: ActiveValue::Set(epoch),
569 replica_id: ActiveValue::Set(replica_id),
570 lamport_timestamp: ActiveValue::Set(lamport_timestamp),
571 })
572 .on_conflict(
573 OnConflict::columns([Column::UserId, Column::BufferId])
574 .update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId])
575 .action_cond_where(
576 Condition::any().add(Column::Epoch.lt(epoch)).add(
577 Condition::all().add(Column::Epoch.eq(epoch)).add(
578 Condition::any()
579 .add(Column::LamportTimestamp.lt(lamport_timestamp))
580 .add(
581 Column::LamportTimestamp
582 .eq(lamport_timestamp)
583 .and(Column::ReplicaId.lt(replica_id)),
584 ),
585 ),
586 ),
587 )
588 .to_owned(),
589 )
590 .exec_without_returning(tx)
591 .await?;
592
593 Ok(())
594 }
595
596 async fn get_buffer_operation_serialization_version(
597 &self,
598 buffer_id: BufferId,
599 epoch: i32,
600 tx: &DatabaseTransaction,
601 ) -> Result<i32> {
602 Ok(buffer_snapshot::Entity::find()
603 .filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
604 .filter(buffer_snapshot::Column::Epoch.eq(epoch))
605 .select_only()
606 .column(buffer_snapshot::Column::OperationSerializationVersion)
607 .into_values::<_, QueryOperationSerializationVersion>()
608 .one(&*tx)
609 .await?
610 .ok_or_else(|| anyhow!("missing buffer snapshot"))?)
611 }
612
613 pub async fn get_channel_buffer(
614 &self,
615 channel_id: ChannelId,
616 tx: &DatabaseTransaction,
617 ) -> Result<buffer::Model> {
618 Ok(channel::Model {
619 id: channel_id,
620 ..Default::default()
621 }
622 .find_related(buffer::Entity)
623 .one(&*tx)
624 .await?
625 .ok_or_else(|| anyhow!("no such buffer"))?)
626 }
627
628 async fn get_buffer_state(
629 &self,
630 buffer: &buffer::Model,
631 tx: &DatabaseTransaction,
632 ) -> Result<(
633 String,
634 Vec<proto::Operation>,
635 Option<buffer_operation::Model>,
636 )> {
637 let id = buffer.id;
638 let (base_text, version) = if buffer.epoch > 0 {
639 let snapshot = buffer_snapshot::Entity::find()
640 .filter(
641 buffer_snapshot::Column::BufferId
642 .eq(id)
643 .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
644 )
645 .one(&*tx)
646 .await?
647 .ok_or_else(|| anyhow!("no such snapshot"))?;
648
649 let version = snapshot.operation_serialization_version;
650 (snapshot.text, version)
651 } else {
652 (String::new(), storage::SERIALIZATION_VERSION)
653 };
654
655 let mut rows = buffer_operation::Entity::find()
656 .filter(
657 buffer_operation::Column::BufferId
658 .eq(id)
659 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
660 )
661 .order_by_asc(buffer_operation::Column::LamportTimestamp)
662 .order_by_asc(buffer_operation::Column::ReplicaId)
663 .stream(&*tx)
664 .await?;
665
666 let mut operations = Vec::new();
667 let mut last_row = None;
668 while let Some(row) = rows.next().await {
669 let row = row?;
670 last_row = Some(buffer_operation::Model {
671 buffer_id: row.buffer_id,
672 epoch: row.epoch,
673 lamport_timestamp: row.lamport_timestamp,
674 replica_id: row.lamport_timestamp,
675 value: Default::default(),
676 });
677 operations.push(proto::Operation {
678 variant: Some(operation_from_storage(row, version)?),
679 });
680 }
681
682 Ok((base_text, operations, last_row))
683 }
684
685 async fn snapshot_channel_buffer(
686 &self,
687 channel_id: ChannelId,
688 tx: &DatabaseTransaction,
689 ) -> Result<()> {
690 let buffer = self.get_channel_buffer(channel_id, tx).await?;
691 let (base_text, operations, _) = self.get_buffer_state(&buffer, tx).await?;
692 if operations.is_empty() {
693 return Ok(());
694 }
695
696 let mut text_buffer = text::Buffer::new(0, text::BufferId::new(1).unwrap(), base_text);
697 text_buffer
698 .apply_ops(operations.into_iter().filter_map(operation_from_wire))
699 .unwrap();
700
701 let base_text = text_buffer.text();
702 let epoch = buffer.epoch + 1;
703
704 buffer_snapshot::Model {
705 buffer_id: buffer.id,
706 epoch,
707 text: base_text,
708 operation_serialization_version: storage::SERIALIZATION_VERSION,
709 }
710 .into_active_model()
711 .insert(tx)
712 .await?;
713
714 buffer::ActiveModel {
715 id: ActiveValue::Unchanged(buffer.id),
716 epoch: ActiveValue::Set(epoch),
717 ..Default::default()
718 }
719 .save(tx)
720 .await?;
721
722 Ok(())
723 }
724
725 pub async fn observe_buffer_version(
726 &self,
727 buffer_id: BufferId,
728 user_id: UserId,
729 epoch: i32,
730 version: &[proto::VectorClockEntry],
731 ) -> Result<()> {
732 self.transaction(|tx| async move {
733 // For now, combine concurrent operations.
734 let Some(component) = version.iter().max_by_key(|version| version.timestamp) else {
735 return Ok(());
736 };
737 self.save_max_operation(
738 user_id,
739 buffer_id,
740 epoch,
741 component.replica_id as i32,
742 component.timestamp as i32,
743 &*tx,
744 )
745 .await?;
746 Ok(())
747 })
748 .await
749 }
750
751 pub async fn latest_channel_buffer_changes(
752 &self,
753 channel_ids: &[ChannelId],
754 tx: &DatabaseTransaction,
755 ) -> Result<Vec<proto::ChannelBufferVersion>> {
756 let mut channel_ids_by_buffer_id = HashMap::default();
757 let mut rows = buffer::Entity::find()
758 .filter(buffer::Column::ChannelId.is_in(channel_ids.iter().copied()))
759 .stream(&*tx)
760 .await?;
761 while let Some(row) = rows.next().await {
762 let row = row?;
763 channel_ids_by_buffer_id.insert(row.id, row.channel_id);
764 }
765 drop(rows);
766
767 let latest_operations = self
768 .get_latest_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), &*tx)
769 .await?;
770
771 Ok(latest_operations
772 .iter()
773 .flat_map(|op| {
774 Some(proto::ChannelBufferVersion {
775 channel_id: channel_ids_by_buffer_id.get(&op.buffer_id)?.to_proto(),
776 epoch: op.epoch as u64,
777 version: vec![proto::VectorClockEntry {
778 replica_id: op.replica_id as u32,
779 timestamp: op.lamport_timestamp as u32,
780 }],
781 })
782 })
783 .collect())
784 }
785
786 /// Returns the latest operations for the buffers with the specified IDs.
787 pub async fn get_latest_operations_for_buffers(
788 &self,
789 buffer_ids: impl IntoIterator<Item = BufferId>,
790 tx: &DatabaseTransaction,
791 ) -> Result<Vec<buffer_operation::Model>> {
792 let mut values = String::new();
793 for id in buffer_ids {
794 if !values.is_empty() {
795 values.push_str(", ");
796 }
797 write!(&mut values, "({})", id).unwrap();
798 }
799
800 if values.is_empty() {
801 return Ok(Vec::default());
802 }
803
804 let sql = format!(
805 r#"
806 SELECT
807 *
808 FROM
809 (
810 SELECT
811 *,
812 row_number() OVER (
813 PARTITION BY buffer_id
814 ORDER BY
815 epoch DESC,
816 lamport_timestamp DESC,
817 replica_id DESC
818 ) as row_number
819 FROM buffer_operations
820 WHERE
821 buffer_id in ({values})
822 ) AS last_operations
823 WHERE
824 row_number = 1
825 "#,
826 );
827
828 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
829 Ok(buffer_operation::Entity::find()
830 .from_raw_sql(stmt)
831 .all(&*tx)
832 .await?)
833 }
834}
835
836fn operation_to_storage(
837 operation: &proto::Operation,
838 buffer: &buffer::Model,
839 _format: i32,
840) -> Option<buffer_operation::ActiveModel> {
841 let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
842 proto::operation::Variant::Edit(operation) => (
843 operation.replica_id,
844 operation.lamport_timestamp,
845 storage::Operation {
846 version: version_to_storage(&operation.version),
847 is_undo: false,
848 edit_ranges: operation
849 .ranges
850 .iter()
851 .map(|range| storage::Range {
852 start: range.start,
853 end: range.end,
854 })
855 .collect(),
856 edit_texts: operation.new_text.clone(),
857 undo_counts: Vec::new(),
858 },
859 ),
860 proto::operation::Variant::Undo(operation) => (
861 operation.replica_id,
862 operation.lamport_timestamp,
863 storage::Operation {
864 version: version_to_storage(&operation.version),
865 is_undo: true,
866 edit_ranges: Vec::new(),
867 edit_texts: Vec::new(),
868 undo_counts: operation
869 .counts
870 .iter()
871 .map(|entry| storage::UndoCount {
872 replica_id: entry.replica_id,
873 lamport_timestamp: entry.lamport_timestamp,
874 count: entry.count,
875 })
876 .collect(),
877 },
878 ),
879 _ => None?,
880 };
881
882 Some(buffer_operation::ActiveModel {
883 buffer_id: ActiveValue::Set(buffer.id),
884 epoch: ActiveValue::Set(buffer.epoch),
885 replica_id: ActiveValue::Set(replica_id as i32),
886 lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
887 value: ActiveValue::Set(value.encode_to_vec()),
888 })
889}
890
891fn operation_from_storage(
892 row: buffer_operation::Model,
893 _format_version: i32,
894) -> Result<proto::operation::Variant, Error> {
895 let operation =
896 storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
897 let version = version_from_storage(&operation.version);
898 Ok(if operation.is_undo {
899 proto::operation::Variant::Undo(proto::operation::Undo {
900 replica_id: row.replica_id as u32,
901 lamport_timestamp: row.lamport_timestamp as u32,
902 version,
903 counts: operation
904 .undo_counts
905 .iter()
906 .map(|entry| proto::UndoCount {
907 replica_id: entry.replica_id,
908 lamport_timestamp: entry.lamport_timestamp,
909 count: entry.count,
910 })
911 .collect(),
912 })
913 } else {
914 proto::operation::Variant::Edit(proto::operation::Edit {
915 replica_id: row.replica_id as u32,
916 lamport_timestamp: row.lamport_timestamp as u32,
917 version,
918 ranges: operation
919 .edit_ranges
920 .into_iter()
921 .map(|range| proto::Range {
922 start: range.start,
923 end: range.end,
924 })
925 .collect(),
926 new_text: operation.edit_texts,
927 })
928 })
929}
930
931fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
932 version
933 .iter()
934 .map(|entry| storage::VectorClockEntry {
935 replica_id: entry.replica_id,
936 timestamp: entry.timestamp,
937 })
938 .collect()
939}
940
941fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
942 version
943 .iter()
944 .map(|entry| proto::VectorClockEntry {
945 replica_id: entry.replica_id,
946 timestamp: entry.timestamp,
947 })
948 .collect()
949}
950
951// This is currently a manual copy of the deserialization code in the client's language crate
952pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
953 match operation.variant? {
954 proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
955 timestamp: clock::Lamport {
956 replica_id: edit.replica_id as text::ReplicaId,
957 value: edit.lamport_timestamp,
958 },
959 version: version_from_wire(&edit.version),
960 ranges: edit
961 .ranges
962 .into_iter()
963 .map(|range| {
964 text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
965 })
966 .collect(),
967 new_text: edit.new_text.into_iter().map(Arc::from).collect(),
968 })),
969 proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
970 timestamp: clock::Lamport {
971 replica_id: undo.replica_id as text::ReplicaId,
972 value: undo.lamport_timestamp,
973 },
974 version: version_from_wire(&undo.version),
975 counts: undo
976 .counts
977 .into_iter()
978 .map(|c| {
979 (
980 clock::Lamport {
981 replica_id: c.replica_id as text::ReplicaId,
982 value: c.lamport_timestamp,
983 },
984 c.count,
985 )
986 })
987 .collect(),
988 })),
989 _ => None,
990 }
991}
992
993fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
994 let mut version = clock::Global::new();
995 for entry in message {
996 version.observe(clock::Lamport {
997 replica_id: entry.replica_id as text::ReplicaId,
998 value: entry.timestamp,
999 });
1000 }
1001 version
1002}
1003
1004fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
1005 let mut message = Vec::new();
1006 for entry in version.iter() {
1007 message.push(proto::VectorClockEntry {
1008 replica_id: entry.replica_id as u32,
1009 timestamp: entry.value,
1010 });
1011 }
1012 message
1013}
1014
/// Single-column selector used with `into_values` to read only a buffer snapshot's
/// `OperationSerializationVersion` column.
// The variant name is mapped to a column by `DeriveColumn`; do not rename it.
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryOperationSerializationVersion {
    OperationSerializationVersion,
}
1019
/// Persisted encoding of buffer operations.
///
/// `Operation` values are prost-encoded into the `value` column of buffer operation
/// rows and decoded from it again. The field tags below are therefore part of the
/// stored format and must not be renumbered or reused.
mod storage {
    // NOTE(review): no non-snake-case identifiers are visible in this module; this
    // allow may be a leftover. Confirm before removing it.
    #![allow(non_snake_case)]
    use prost::Message;

    /// Serialization version written by this code. It is recorded on each buffer
    /// snapshot as `operation_serialization_version`.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// A stored edit or undo operation.
    ///
    /// The replica id and lamport timestamp are not part of this message; they are
    /// kept in separate row columns.
    #[derive(Message)]
    pub struct Operation {
        /// Version vector of the operation. Tag 1 is not used here (presumably
        /// retired); confirm before ever reusing it.
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        /// `true` for an undo operation, `false` for an edit.
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        /// Edited ranges (edits only).
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        /// New text for the edit (edits only); carried over from the wire `new_text`.
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        /// Undo counts (undos only).
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// One entry of a stored version vector.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// A stored edit range, as start/end offsets.
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// A stored undo count: the operation identified by `replica_id` and
    /// `lamport_timestamp`, paired with `count`.
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub lamport_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}