1use super::*;
2use prost::Message;
3use text::{EditOperation, UndoOperation};
4
/// Outcome of removing a connection from a channel buffer.
///
/// Produced by `leave_channel_buffer_internal` from the collaborator rows that are
/// still present for the channel after the leaving connection's row was deleted.
pub struct LeftChannelBuffer {
    /// The channel whose buffer was left.
    pub channel_id: ChannelId,
    /// The collaborators that remain on the buffer, in wire format.
    pub collaborators: Vec<proto::Collaborator>,
    /// The connections of the remaining collaborators.
    pub connections: Vec<ConnectionId>,
}
10
11impl Database {
12 /// Open a channel buffer. Returns the current contents, and adds you to the list of people
13 /// to notify on changes.
14 pub async fn join_channel_buffer(
15 &self,
16 channel_id: ChannelId,
17 user_id: UserId,
18 connection: ConnectionId,
19 ) -> Result<proto::JoinChannelBufferResponse> {
20 self.transaction(|tx| async move {
21 let channel = self.get_channel_internal(channel_id, &tx).await?;
22 self.check_user_is_channel_participant(&channel, user_id, &tx)
23 .await?;
24
25 let buffer = channel::Model {
26 id: channel_id,
27 ..Default::default()
28 }
29 .find_related(buffer::Entity)
30 .one(&*tx)
31 .await?;
32
33 let buffer = if let Some(buffer) = buffer {
34 buffer
35 } else {
36 let buffer = buffer::ActiveModel {
37 channel_id: ActiveValue::Set(channel_id),
38 ..Default::default()
39 }
40 .insert(&*tx)
41 .await?;
42 buffer_snapshot::ActiveModel {
43 buffer_id: ActiveValue::Set(buffer.id),
44 epoch: ActiveValue::Set(0),
45 text: ActiveValue::Set(String::new()),
46 operation_serialization_version: ActiveValue::Set(
47 storage::SERIALIZATION_VERSION,
48 ),
49 }
50 .insert(&*tx)
51 .await?;
52 buffer
53 };
54
55 // Join the collaborators
56 let mut collaborators = channel_buffer_collaborator::Entity::find()
57 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
58 .all(&*tx)
59 .await?;
60 let replica_ids = collaborators
61 .iter()
62 .map(|c| c.replica_id)
63 .collect::<HashSet<_>>();
64 let mut replica_id = ReplicaId(0);
65 while replica_ids.contains(&replica_id) {
66 replica_id.0 += 1;
67 }
68 let collaborator = channel_buffer_collaborator::ActiveModel {
69 channel_id: ActiveValue::Set(channel_id),
70 connection_id: ActiveValue::Set(connection.id as i32),
71 connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
72 user_id: ActiveValue::Set(user_id),
73 replica_id: ActiveValue::Set(replica_id),
74 ..Default::default()
75 }
76 .insert(&*tx)
77 .await?;
78 collaborators.push(collaborator);
79
80 let (base_text, operations, max_operation) =
81 self.get_buffer_state(&buffer, &tx).await?;
82
83 // Save the last observed operation
84 if let Some(op) = max_operation {
85 observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
86 user_id: ActiveValue::Set(user_id),
87 buffer_id: ActiveValue::Set(buffer.id),
88 epoch: ActiveValue::Set(op.epoch),
89 lamport_timestamp: ActiveValue::Set(op.lamport_timestamp),
90 replica_id: ActiveValue::Set(op.replica_id),
91 })
92 .on_conflict(
93 OnConflict::columns([
94 observed_buffer_edits::Column::UserId,
95 observed_buffer_edits::Column::BufferId,
96 ])
97 .update_columns([
98 observed_buffer_edits::Column::Epoch,
99 observed_buffer_edits::Column::LamportTimestamp,
100 ])
101 .to_owned(),
102 )
103 .exec(&*tx)
104 .await?;
105 }
106
107 Ok(proto::JoinChannelBufferResponse {
108 buffer_id: buffer.id.to_proto(),
109 replica_id: replica_id.to_proto() as u32,
110 base_text,
111 operations,
112 epoch: buffer.epoch as u64,
113 collaborators: collaborators
114 .into_iter()
115 .map(|collaborator| proto::Collaborator {
116 peer_id: Some(collaborator.connection().into()),
117 user_id: collaborator.user_id.to_proto(),
118 replica_id: collaborator.replica_id.0 as u32,
119 })
120 .collect(),
121 })
122 })
123 .await
124 }
125
126 /// Rejoin a channel buffer (after a connection interruption)
127 pub async fn rejoin_channel_buffers(
128 &self,
129 buffers: &[proto::ChannelBufferVersion],
130 user_id: UserId,
131 connection_id: ConnectionId,
132 ) -> Result<Vec<RejoinedChannelBuffer>> {
133 self.transaction(|tx| async move {
134 let mut results = Vec::new();
135 for client_buffer in buffers {
136 let channel = self
137 .get_channel_internal(ChannelId::from_proto(client_buffer.channel_id), &tx)
138 .await?;
139 if self
140 .check_user_is_channel_participant(&channel, user_id, &tx)
141 .await
142 .is_err()
143 {
144 log::info!("user is not a member of channel");
145 continue;
146 }
147
148 let buffer = self.get_channel_buffer(channel.id, &tx).await?;
149 let mut collaborators = channel_buffer_collaborator::Entity::find()
150 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel.id))
151 .all(&*tx)
152 .await?;
153
154 // If the buffer epoch hasn't changed since the client lost
155 // connection, then the client's buffer can be synchronized with
156 // the server's buffer.
157 if buffer.epoch as u64 != client_buffer.epoch {
158 log::info!("can't rejoin buffer, epoch has changed");
159 continue;
160 }
161
162 // Find the collaborator record for this user's previous lost
163 // connection. Update it with the new connection id.
164 let Some(self_collaborator) =
165 collaborators.iter_mut().find(|c| c.user_id == user_id)
166 else {
167 log::info!("can't rejoin buffer, no previous collaborator found");
168 continue;
169 };
170 let old_connection_id = self_collaborator.connection();
171 *self_collaborator = channel_buffer_collaborator::ActiveModel {
172 id: ActiveValue::Unchanged(self_collaborator.id),
173 connection_id: ActiveValue::Set(connection_id.id as i32),
174 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
175 connection_lost: ActiveValue::Set(false),
176 ..Default::default()
177 }
178 .update(&*tx)
179 .await?;
180
181 let client_version = version_from_wire(&client_buffer.version);
182 let serialization_version = self
183 .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &tx)
184 .await?;
185
186 let mut rows = buffer_operation::Entity::find()
187 .filter(
188 buffer_operation::Column::BufferId
189 .eq(buffer.id)
190 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
191 )
192 .stream(&*tx)
193 .await?;
194
195 // Find the server's version vector and any operations
196 // that the client has not seen.
197 let mut server_version = clock::Global::new();
198 let mut operations = Vec::new();
199 while let Some(row) = rows.next().await {
200 let row = row?;
201 let timestamp = clock::Lamport {
202 replica_id: row.replica_id as u16,
203 value: row.lamport_timestamp as u32,
204 };
205 server_version.observe(timestamp);
206 if !client_version.observed(timestamp) {
207 operations.push(proto::Operation {
208 variant: Some(operation_from_storage(row, serialization_version)?),
209 })
210 }
211 }
212
213 results.push(RejoinedChannelBuffer {
214 old_connection_id,
215 buffer: proto::RejoinedChannelBuffer {
216 channel_id: client_buffer.channel_id,
217 version: version_to_wire(&server_version),
218 operations,
219 collaborators: collaborators
220 .into_iter()
221 .map(|collaborator| proto::Collaborator {
222 peer_id: Some(collaborator.connection().into()),
223 user_id: collaborator.user_id.to_proto(),
224 replica_id: collaborator.replica_id.0 as u32,
225 })
226 .collect(),
227 },
228 });
229 }
230
231 Ok(results)
232 })
233 .await
234 }
235
236 /// Clear out any buffer collaborators who are no longer collaborating.
237 pub async fn clear_stale_channel_buffer_collaborators(
238 &self,
239 channel_id: ChannelId,
240 server_id: ServerId,
241 ) -> Result<RefreshedChannelBuffer> {
242 self.transaction(|tx| async move {
243 let db_collaborators = channel_buffer_collaborator::Entity::find()
244 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
245 .all(&*tx)
246 .await?;
247
248 let mut connection_ids = Vec::new();
249 let mut collaborators = Vec::new();
250 let mut collaborator_ids_to_remove = Vec::new();
251 for db_collaborator in &db_collaborators {
252 if !db_collaborator.connection_lost
253 && db_collaborator.connection_server_id == server_id
254 {
255 connection_ids.push(db_collaborator.connection());
256 collaborators.push(proto::Collaborator {
257 peer_id: Some(db_collaborator.connection().into()),
258 replica_id: db_collaborator.replica_id.0 as u32,
259 user_id: db_collaborator.user_id.to_proto(),
260 })
261 } else {
262 collaborator_ids_to_remove.push(db_collaborator.id);
263 }
264 }
265
266 channel_buffer_collaborator::Entity::delete_many()
267 .filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove))
268 .exec(&*tx)
269 .await?;
270
271 Ok(RefreshedChannelBuffer {
272 connection_ids,
273 collaborators,
274 })
275 })
276 .await
277 }
278
279 /// Close the channel buffer, and stop receiving updates for it.
280 pub async fn leave_channel_buffer(
281 &self,
282 channel_id: ChannelId,
283 connection: ConnectionId,
284 ) -> Result<LeftChannelBuffer> {
285 self.transaction(|tx| async move {
286 self.leave_channel_buffer_internal(channel_id, connection, &tx)
287 .await
288 })
289 .await
290 }
291
292 /// Close the channel buffer, and stop receiving updates for it.
293 pub async fn channel_buffer_connection_lost(
294 &self,
295 connection: ConnectionId,
296 tx: &DatabaseTransaction,
297 ) -> Result<()> {
298 channel_buffer_collaborator::Entity::update_many()
299 .filter(
300 Condition::all()
301 .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
302 .add(
303 channel_buffer_collaborator::Column::ConnectionServerId
304 .eq(connection.owner_id as i32),
305 ),
306 )
307 .set(channel_buffer_collaborator::ActiveModel {
308 connection_lost: ActiveValue::set(true),
309 ..Default::default()
310 })
311 .exec(tx)
312 .await?;
313 Ok(())
314 }
315
316 /// Close all open channel buffers
317 pub async fn leave_channel_buffers(
318 &self,
319 connection: ConnectionId,
320 ) -> Result<Vec<LeftChannelBuffer>> {
321 self.transaction(|tx| async move {
322 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
323 enum QueryChannelIds {
324 ChannelId,
325 }
326
327 let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
328 .select_only()
329 .column(channel_buffer_collaborator::Column::ChannelId)
330 .filter(Condition::all().add(
331 channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
332 ))
333 .into_values::<_, QueryChannelIds>()
334 .all(&*tx)
335 .await?;
336
337 let mut result = Vec::new();
338 for channel_id in channel_ids {
339 let left_channel_buffer = self
340 .leave_channel_buffer_internal(channel_id, connection, &tx)
341 .await?;
342 result.push(left_channel_buffer);
343 }
344
345 Ok(result)
346 })
347 .await
348 }
349
350 async fn leave_channel_buffer_internal(
351 &self,
352 channel_id: ChannelId,
353 connection: ConnectionId,
354 tx: &DatabaseTransaction,
355 ) -> Result<LeftChannelBuffer> {
356 let result = channel_buffer_collaborator::Entity::delete_many()
357 .filter(
358 Condition::all()
359 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
360 .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
361 .add(
362 channel_buffer_collaborator::Column::ConnectionServerId
363 .eq(connection.owner_id as i32),
364 ),
365 )
366 .exec(tx)
367 .await?;
368 if result.rows_affected == 0 {
369 Err(anyhow!("not a collaborator on this project"))?;
370 }
371
372 let mut collaborators = Vec::new();
373 let mut connections = Vec::new();
374 let mut rows = channel_buffer_collaborator::Entity::find()
375 .filter(
376 Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
377 )
378 .stream(tx)
379 .await?;
380 while let Some(row) = rows.next().await {
381 let row = row?;
382 let connection = row.connection();
383 connections.push(connection);
384 collaborators.push(proto::Collaborator {
385 peer_id: Some(connection.into()),
386 replica_id: row.replica_id.0 as u32,
387 user_id: row.user_id.to_proto(),
388 });
389 }
390
391 drop(rows);
392
393 if collaborators.is_empty() {
394 self.snapshot_channel_buffer(channel_id, &tx).await?;
395 }
396
397 Ok(LeftChannelBuffer {
398 channel_id,
399 collaborators,
400 connections,
401 })
402 }
403
404 pub async fn get_channel_buffer_collaborators(
405 &self,
406 channel_id: ChannelId,
407 ) -> Result<Vec<UserId>> {
408 self.transaction(|tx| async move {
409 self.get_channel_buffer_collaborators_internal(channel_id, &tx)
410 .await
411 })
412 .await
413 }
414
415 async fn get_channel_buffer_collaborators_internal(
416 &self,
417 channel_id: ChannelId,
418 tx: &DatabaseTransaction,
419 ) -> Result<Vec<UserId>> {
420 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
421 enum QueryUserIds {
422 UserId,
423 }
424
425 let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
426 .select_only()
427 .column(channel_buffer_collaborator::Column::UserId)
428 .filter(
429 Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
430 )
431 .into_values::<_, QueryUserIds>()
432 .all(tx)
433 .await?;
434
435 Ok(users)
436 }
437
438 pub async fn update_channel_buffer(
439 &self,
440 channel_id: ChannelId,
441 user: UserId,
442 operations: &[proto::Operation],
443 ) -> Result<(
444 Vec<ConnectionId>,
445 Vec<UserId>,
446 i32,
447 Vec<proto::VectorClockEntry>,
448 )> {
449 self.transaction(move |tx| async move {
450 let channel = self.get_channel_internal(channel_id, &tx).await?;
451
452 let mut requires_write_permission = false;
453 for op in operations.iter() {
454 match op.variant {
455 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
456 Some(_) => requires_write_permission = true,
457 }
458 }
459 if requires_write_permission {
460 self.check_user_is_channel_member(&channel, user, &tx)
461 .await?;
462 } else {
463 self.check_user_is_channel_participant(&channel, user, &tx)
464 .await?;
465 }
466
467 let buffer = buffer::Entity::find()
468 .filter(buffer::Column::ChannelId.eq(channel_id))
469 .one(&*tx)
470 .await?
471 .ok_or_else(|| anyhow!("no such buffer"))?;
472
473 let serialization_version = self
474 .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &tx)
475 .await?;
476
477 let operations = operations
478 .iter()
479 .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
480 .collect::<Vec<_>>();
481
482 let mut channel_members;
483 let max_version;
484
485 if !operations.is_empty() {
486 let max_operation = operations
487 .iter()
488 .max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
489 .unwrap();
490
491 max_version = vec![proto::VectorClockEntry {
492 replica_id: *max_operation.replica_id.as_ref() as u32,
493 timestamp: *max_operation.lamport_timestamp.as_ref() as u32,
494 }];
495
496 // get current channel participants and save the max operation above
497 self.save_max_operation(
498 user,
499 buffer.id,
500 buffer.epoch,
501 *max_operation.replica_id.as_ref(),
502 *max_operation.lamport_timestamp.as_ref(),
503 &tx,
504 )
505 .await?;
506
507 channel_members = self.get_channel_participants(&channel, &tx).await?;
508 let collaborators = self
509 .get_channel_buffer_collaborators_internal(channel_id, &tx)
510 .await?;
511 channel_members.retain(|member| !collaborators.contains(member));
512
513 buffer_operation::Entity::insert_many(operations)
514 .on_conflict(
515 OnConflict::columns([
516 buffer_operation::Column::BufferId,
517 buffer_operation::Column::Epoch,
518 buffer_operation::Column::LamportTimestamp,
519 buffer_operation::Column::ReplicaId,
520 ])
521 .do_nothing()
522 .to_owned(),
523 )
524 .exec(&*tx)
525 .await?;
526 } else {
527 channel_members = Vec::new();
528 max_version = Vec::new();
529 }
530
531 let mut connections = Vec::new();
532 let mut rows = channel_buffer_collaborator::Entity::find()
533 .filter(
534 Condition::all()
535 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
536 )
537 .stream(&*tx)
538 .await?;
539 while let Some(row) = rows.next().await {
540 let row = row?;
541 connections.push(ConnectionId {
542 id: row.connection_id as u32,
543 owner_id: row.connection_server_id.0 as u32,
544 });
545 }
546
547 Ok((connections, channel_members, buffer.epoch, max_version))
548 })
549 .await
550 }
551
552 async fn save_max_operation(
553 &self,
554 user_id: UserId,
555 buffer_id: BufferId,
556 epoch: i32,
557 replica_id: i32,
558 lamport_timestamp: i32,
559 tx: &DatabaseTransaction,
560 ) -> Result<()> {
561 use observed_buffer_edits::Column;
562 observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
563 user_id: ActiveValue::Set(user_id),
564 buffer_id: ActiveValue::Set(buffer_id),
565 epoch: ActiveValue::Set(epoch),
566 replica_id: ActiveValue::Set(replica_id),
567 lamport_timestamp: ActiveValue::Set(lamport_timestamp),
568 })
569 .on_conflict(
570 OnConflict::columns([Column::UserId, Column::BufferId])
571 .update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId])
572 .action_cond_where(
573 Condition::any().add(Column::Epoch.lt(epoch)).add(
574 Condition::all().add(Column::Epoch.eq(epoch)).add(
575 Condition::any()
576 .add(Column::LamportTimestamp.lt(lamport_timestamp))
577 .add(
578 Column::LamportTimestamp
579 .eq(lamport_timestamp)
580 .and(Column::ReplicaId.lt(replica_id)),
581 ),
582 ),
583 ),
584 )
585 .to_owned(),
586 )
587 .exec_without_returning(tx)
588 .await?;
589
590 Ok(())
591 }
592
593 async fn get_buffer_operation_serialization_version(
594 &self,
595 buffer_id: BufferId,
596 epoch: i32,
597 tx: &DatabaseTransaction,
598 ) -> Result<i32> {
599 Ok(buffer_snapshot::Entity::find()
600 .filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
601 .filter(buffer_snapshot::Column::Epoch.eq(epoch))
602 .select_only()
603 .column(buffer_snapshot::Column::OperationSerializationVersion)
604 .into_values::<_, QueryOperationSerializationVersion>()
605 .one(tx)
606 .await?
607 .ok_or_else(|| anyhow!("missing buffer snapshot"))?)
608 }
609
610 pub async fn get_channel_buffer(
611 &self,
612 channel_id: ChannelId,
613 tx: &DatabaseTransaction,
614 ) -> Result<buffer::Model> {
615 Ok(channel::Model {
616 id: channel_id,
617 ..Default::default()
618 }
619 .find_related(buffer::Entity)
620 .one(tx)
621 .await?
622 .ok_or_else(|| anyhow!("no such buffer"))?)
623 }
624
625 async fn get_buffer_state(
626 &self,
627 buffer: &buffer::Model,
628 tx: &DatabaseTransaction,
629 ) -> Result<(
630 String,
631 Vec<proto::Operation>,
632 Option<buffer_operation::Model>,
633 )> {
634 let id = buffer.id;
635 let (base_text, version) = if buffer.epoch > 0 {
636 let snapshot = buffer_snapshot::Entity::find()
637 .filter(
638 buffer_snapshot::Column::BufferId
639 .eq(id)
640 .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
641 )
642 .one(tx)
643 .await?
644 .ok_or_else(|| anyhow!("no such snapshot"))?;
645
646 let version = snapshot.operation_serialization_version;
647 (snapshot.text, version)
648 } else {
649 (String::new(), storage::SERIALIZATION_VERSION)
650 };
651
652 let mut rows = buffer_operation::Entity::find()
653 .filter(
654 buffer_operation::Column::BufferId
655 .eq(id)
656 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
657 )
658 .order_by_asc(buffer_operation::Column::LamportTimestamp)
659 .order_by_asc(buffer_operation::Column::ReplicaId)
660 .stream(tx)
661 .await?;
662
663 let mut operations = Vec::new();
664 let mut last_row = None;
665 while let Some(row) = rows.next().await {
666 let row = row?;
667 last_row = Some(buffer_operation::Model {
668 buffer_id: row.buffer_id,
669 epoch: row.epoch,
670 lamport_timestamp: row.lamport_timestamp,
671 replica_id: row.replica_id,
672 value: Default::default(),
673 });
674 operations.push(proto::Operation {
675 variant: Some(operation_from_storage(row, version)?),
676 });
677 }
678
679 Ok((base_text, operations, last_row))
680 }
681
682 async fn snapshot_channel_buffer(
683 &self,
684 channel_id: ChannelId,
685 tx: &DatabaseTransaction,
686 ) -> Result<()> {
687 let buffer = self.get_channel_buffer(channel_id, tx).await?;
688 let (base_text, operations, _) = self.get_buffer_state(&buffer, tx).await?;
689 if operations.is_empty() {
690 return Ok(());
691 }
692
693 let mut text_buffer = text::Buffer::new(0, text::BufferId::new(1).unwrap(), base_text);
694 text_buffer
695 .apply_ops(operations.into_iter().filter_map(operation_from_wire))
696 .unwrap();
697
698 let base_text = text_buffer.text();
699 let epoch = buffer.epoch + 1;
700
701 buffer_snapshot::Model {
702 buffer_id: buffer.id,
703 epoch,
704 text: base_text,
705 operation_serialization_version: storage::SERIALIZATION_VERSION,
706 }
707 .into_active_model()
708 .insert(tx)
709 .await?;
710
711 buffer::ActiveModel {
712 id: ActiveValue::Unchanged(buffer.id),
713 epoch: ActiveValue::Set(epoch),
714 ..Default::default()
715 }
716 .save(tx)
717 .await?;
718
719 Ok(())
720 }
721
722 pub async fn observe_buffer_version(
723 &self,
724 buffer_id: BufferId,
725 user_id: UserId,
726 epoch: i32,
727 version: &[proto::VectorClockEntry],
728 ) -> Result<()> {
729 self.transaction(|tx| async move {
730 // For now, combine concurrent operations.
731 let Some(component) = version.iter().max_by_key(|version| version.timestamp) else {
732 return Ok(());
733 };
734 self.save_max_operation(
735 user_id,
736 buffer_id,
737 epoch,
738 component.replica_id as i32,
739 component.timestamp as i32,
740 &tx,
741 )
742 .await?;
743 Ok(())
744 })
745 .await
746 }
747
748 pub async fn latest_channel_buffer_changes(
749 &self,
750 channel_ids_by_buffer_id: &HashMap<BufferId, ChannelId>,
751 tx: &DatabaseTransaction,
752 ) -> Result<Vec<proto::ChannelBufferVersion>> {
753 let latest_operations = self
754 .get_latest_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), tx)
755 .await?;
756
757 Ok(latest_operations
758 .iter()
759 .flat_map(|op| {
760 Some(proto::ChannelBufferVersion {
761 channel_id: channel_ids_by_buffer_id.get(&op.buffer_id)?.to_proto(),
762 epoch: op.epoch as u64,
763 version: vec![proto::VectorClockEntry {
764 replica_id: op.replica_id as u32,
765 timestamp: op.lamport_timestamp as u32,
766 }],
767 })
768 })
769 .collect())
770 }
771
772 pub async fn observed_channel_buffer_changes(
773 &self,
774 channel_ids_by_buffer_id: &HashMap<BufferId, ChannelId>,
775 user_id: UserId,
776 tx: &DatabaseTransaction,
777 ) -> Result<Vec<proto::ChannelBufferVersion>> {
778 let observed_operations = observed_buffer_edits::Entity::find()
779 .filter(observed_buffer_edits::Column::UserId.eq(user_id))
780 .filter(
781 observed_buffer_edits::Column::BufferId
782 .is_in(channel_ids_by_buffer_id.keys().copied()),
783 )
784 .all(tx)
785 .await?;
786
787 Ok(observed_operations
788 .iter()
789 .flat_map(|op| {
790 Some(proto::ChannelBufferVersion {
791 channel_id: channel_ids_by_buffer_id.get(&op.buffer_id)?.to_proto(),
792 epoch: op.epoch as u64,
793 version: vec![proto::VectorClockEntry {
794 replica_id: op.replica_id as u32,
795 timestamp: op.lamport_timestamp as u32,
796 }],
797 })
798 })
799 .collect())
800 }
801
802 /// Returns the latest operations for the buffers with the specified IDs.
803 pub async fn get_latest_operations_for_buffers(
804 &self,
805 buffer_ids: impl IntoIterator<Item = BufferId>,
806 tx: &DatabaseTransaction,
807 ) -> Result<Vec<buffer_operation::Model>> {
808 let mut values = String::new();
809 for id in buffer_ids {
810 if !values.is_empty() {
811 values.push_str(", ");
812 }
813 write!(&mut values, "({})", id).unwrap();
814 }
815
816 if values.is_empty() {
817 return Ok(Vec::default());
818 }
819
820 let sql = format!(
821 r#"
822 SELECT
823 *
824 FROM
825 (
826 SELECT
827 *,
828 row_number() OVER (
829 PARTITION BY buffer_id
830 ORDER BY
831 epoch DESC,
832 lamport_timestamp DESC,
833 replica_id DESC
834 ) as row_number
835 FROM buffer_operations
836 WHERE
837 buffer_id in ({values})
838 ) AS last_operations
839 WHERE
840 row_number = 1
841 "#,
842 );
843
844 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
845 Ok(buffer_operation::Entity::find()
846 .from_raw_sql(stmt)
847 .all(tx)
848 .await?)
849 }
850}
851
852fn operation_to_storage(
853 operation: &proto::Operation,
854 buffer: &buffer::Model,
855 _format: i32,
856) -> Option<buffer_operation::ActiveModel> {
857 let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
858 proto::operation::Variant::Edit(operation) => (
859 operation.replica_id,
860 operation.lamport_timestamp,
861 storage::Operation {
862 version: version_to_storage(&operation.version),
863 is_undo: false,
864 edit_ranges: operation
865 .ranges
866 .iter()
867 .map(|range| storage::Range {
868 start: range.start,
869 end: range.end,
870 })
871 .collect(),
872 edit_texts: operation.new_text.clone(),
873 undo_counts: Vec::new(),
874 },
875 ),
876 proto::operation::Variant::Undo(operation) => (
877 operation.replica_id,
878 operation.lamport_timestamp,
879 storage::Operation {
880 version: version_to_storage(&operation.version),
881 is_undo: true,
882 edit_ranges: Vec::new(),
883 edit_texts: Vec::new(),
884 undo_counts: operation
885 .counts
886 .iter()
887 .map(|entry| storage::UndoCount {
888 replica_id: entry.replica_id,
889 lamport_timestamp: entry.lamport_timestamp,
890 count: entry.count,
891 })
892 .collect(),
893 },
894 ),
895 _ => None?,
896 };
897
898 Some(buffer_operation::ActiveModel {
899 buffer_id: ActiveValue::Set(buffer.id),
900 epoch: ActiveValue::Set(buffer.epoch),
901 replica_id: ActiveValue::Set(replica_id as i32),
902 lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
903 value: ActiveValue::Set(value.encode_to_vec()),
904 })
905}
906
907fn operation_from_storage(
908 row: buffer_operation::Model,
909 _format_version: i32,
910) -> Result<proto::operation::Variant, Error> {
911 let operation =
912 storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
913 let version = version_from_storage(&operation.version);
914 Ok(if operation.is_undo {
915 proto::operation::Variant::Undo(proto::operation::Undo {
916 replica_id: row.replica_id as u32,
917 lamport_timestamp: row.lamport_timestamp as u32,
918 version,
919 counts: operation
920 .undo_counts
921 .iter()
922 .map(|entry| proto::UndoCount {
923 replica_id: entry.replica_id,
924 lamport_timestamp: entry.lamport_timestamp,
925 count: entry.count,
926 })
927 .collect(),
928 })
929 } else {
930 proto::operation::Variant::Edit(proto::operation::Edit {
931 replica_id: row.replica_id as u32,
932 lamport_timestamp: row.lamport_timestamp as u32,
933 version,
934 ranges: operation
935 .edit_ranges
936 .into_iter()
937 .map(|range| proto::Range {
938 start: range.start,
939 end: range.end,
940 })
941 .collect(),
942 new_text: operation.edit_texts,
943 })
944 })
945}
946
947fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
948 version
949 .iter()
950 .map(|entry| storage::VectorClockEntry {
951 replica_id: entry.replica_id,
952 timestamp: entry.timestamp,
953 })
954 .collect()
955}
956
957fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
958 version
959 .iter()
960 .map(|entry| proto::VectorClockEntry {
961 replica_id: entry.replica_id,
962 timestamp: entry.timestamp,
963 })
964 .collect()
965}
966
967// This is currently a manual copy of the deserialization code in the client's language crate
968pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
969 match operation.variant? {
970 proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
971 timestamp: clock::Lamport {
972 replica_id: edit.replica_id as text::ReplicaId,
973 value: edit.lamport_timestamp,
974 },
975 version: version_from_wire(&edit.version),
976 ranges: edit
977 .ranges
978 .into_iter()
979 .map(|range| {
980 text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
981 })
982 .collect(),
983 new_text: edit.new_text.into_iter().map(Arc::from).collect(),
984 })),
985 proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
986 timestamp: clock::Lamport {
987 replica_id: undo.replica_id as text::ReplicaId,
988 value: undo.lamport_timestamp,
989 },
990 version: version_from_wire(&undo.version),
991 counts: undo
992 .counts
993 .into_iter()
994 .map(|c| {
995 (
996 clock::Lamport {
997 replica_id: c.replica_id as text::ReplicaId,
998 value: c.lamport_timestamp,
999 },
1000 c.count,
1001 )
1002 })
1003 .collect(),
1004 })),
1005 _ => None,
1006 }
1007}
1008
1009fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
1010 let mut version = clock::Global::new();
1011 for entry in message {
1012 version.observe(clock::Lamport {
1013 replica_id: entry.replica_id as text::ReplicaId,
1014 value: entry.timestamp,
1015 });
1016 }
1017 version
1018}
1019
1020fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
1021 let mut message = Vec::new();
1022 for entry in version.iter() {
1023 message.push(proto::VectorClockEntry {
1024 replica_id: entry.replica_id as u32,
1025 timestamp: entry.value,
1026 });
1027 }
1028 message
1029}
1030
/// Column selector used by `get_buffer_operation_serialization_version` with
/// `into_values` to read only the `operation_serialization_version` column of a
/// buffer snapshot.
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryOperationSerializationVersion {
    OperationSerializationVersion,
}
1035
/// Protobuf message definitions for the bytes kept in a `buffer_operation` row's
/// `value` column.
///
/// `operation_to_storage` encodes these messages and `operation_from_storage` decodes
/// them. NOTE(review): the field tags identify fields in already-stored data, so they
/// presumably must not be renumbered -- confirm before changing any tag.
mod storage {
    #![allow(non_snake_case)]
    use prost::Message;
    /// Serialization version written to newly inserted buffer snapshots.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// A stored edit or undo operation. Field tags start at 2; tag 1 is not used here.
    #[derive(Message)]
    pub struct Operation {
        /// Version vector carried by the operation.
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        /// `true` for an undo operation, `false` for an edit.
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        /// Ranges of an edit; left empty for undo operations.
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        /// New texts of an edit; left empty for undo operations.
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        /// Undo counts of an undo; left empty for edit operations.
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// One entry of a version vector: a replica id and a timestamp.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// A range of an edit, given by its start and end.
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// A count attached to the operation identified by `replica_id` and
    /// `lamport_timestamp`.
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub lamport_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}