1use super::*;
2use prost::Message;
3use text::{EditOperation, UndoOperation};
4
5pub struct LeftChannelBuffer {
6 pub channel_id: ChannelId,
7 pub collaborators: Vec<proto::Collaborator>,
8 pub connections: Vec<ConnectionId>,
9}
10
11impl Database {
12 /// Open a channel buffer. Returns the current contents, and adds you to the list of people
13 /// to notify on changes.
14 pub async fn join_channel_buffer(
15 &self,
16 channel_id: ChannelId,
17 user_id: UserId,
18 connection: ConnectionId,
19 ) -> Result<proto::JoinChannelBufferResponse> {
20 self.transaction(|tx| async move {
21 let channel = self.get_channel_internal(channel_id, &*tx).await?;
22 self.check_user_is_channel_participant(&channel, user_id, &tx)
23 .await?;
24
25 let buffer = channel::Model {
26 id: channel_id,
27 ..Default::default()
28 }
29 .find_related(buffer::Entity)
30 .one(&*tx)
31 .await?;
32
33 let buffer = if let Some(buffer) = buffer {
34 buffer
35 } else {
36 let buffer = buffer::ActiveModel {
37 channel_id: ActiveValue::Set(channel_id),
38 ..Default::default()
39 }
40 .insert(&*tx)
41 .await?;
42 buffer_snapshot::ActiveModel {
43 buffer_id: ActiveValue::Set(buffer.id),
44 epoch: ActiveValue::Set(0),
45 text: ActiveValue::Set(String::new()),
46 operation_serialization_version: ActiveValue::Set(
47 storage::SERIALIZATION_VERSION,
48 ),
49 }
50 .insert(&*tx)
51 .await?;
52 buffer
53 };
54
55 // Join the collaborators
56 let mut collaborators = channel_buffer_collaborator::Entity::find()
57 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
58 .all(&*tx)
59 .await?;
60 let replica_ids = collaborators
61 .iter()
62 .map(|c| c.replica_id)
63 .collect::<HashSet<_>>();
64 let mut replica_id = ReplicaId(0);
65 while replica_ids.contains(&replica_id) {
66 replica_id.0 += 1;
67 }
68 let collaborator = channel_buffer_collaborator::ActiveModel {
69 channel_id: ActiveValue::Set(channel_id),
70 connection_id: ActiveValue::Set(connection.id as i32),
71 connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
72 user_id: ActiveValue::Set(user_id),
73 replica_id: ActiveValue::Set(replica_id),
74 ..Default::default()
75 }
76 .insert(&*tx)
77 .await?;
78 collaborators.push(collaborator);
79
80 let (base_text, operations, max_operation) =
81 self.get_buffer_state(&buffer, &tx).await?;
82
83 // Save the last observed operation
84 if let Some(op) = max_operation {
85 observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
86 user_id: ActiveValue::Set(user_id),
87 buffer_id: ActiveValue::Set(buffer.id),
88 epoch: ActiveValue::Set(op.epoch),
89 lamport_timestamp: ActiveValue::Set(op.lamport_timestamp),
90 replica_id: ActiveValue::Set(op.replica_id),
91 })
92 .on_conflict(
93 OnConflict::columns([
94 observed_buffer_edits::Column::UserId,
95 observed_buffer_edits::Column::BufferId,
96 ])
97 .update_columns([
98 observed_buffer_edits::Column::Epoch,
99 observed_buffer_edits::Column::LamportTimestamp,
100 ])
101 .to_owned(),
102 )
103 .exec(&*tx)
104 .await?;
105 }
106
107 Ok(proto::JoinChannelBufferResponse {
108 buffer_id: buffer.id.to_proto(),
109 replica_id: replica_id.to_proto() as u32,
110 base_text,
111 operations,
112 epoch: buffer.epoch as u64,
113 collaborators: collaborators
114 .into_iter()
115 .map(|collaborator| proto::Collaborator {
116 peer_id: Some(collaborator.connection().into()),
117 user_id: collaborator.user_id.to_proto(),
118 replica_id: collaborator.replica_id.0 as u32,
119 })
120 .collect(),
121 })
122 })
123 .await
124 }
125
126 /// Rejoin a channel buffer (after a connection interruption)
127 pub async fn rejoin_channel_buffers(
128 &self,
129 buffers: &[proto::ChannelBufferVersion],
130 user_id: UserId,
131 connection_id: ConnectionId,
132 ) -> Result<Vec<RejoinedChannelBuffer>> {
133 self.transaction(|tx| async move {
134 let mut results = Vec::new();
135 for client_buffer in buffers {
136 let channel = self
137 .get_channel_internal(ChannelId::from_proto(client_buffer.channel_id), &*tx)
138 .await?;
139 if self
140 .check_user_is_channel_participant(&channel, user_id, &*tx)
141 .await
142 .is_err()
143 {
144 log::info!("user is not a member of channel");
145 continue;
146 }
147
148 let buffer = self.get_channel_buffer(channel.id, &*tx).await?;
149 let mut collaborators = channel_buffer_collaborator::Entity::find()
150 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel.id))
151 .all(&*tx)
152 .await?;
153
154 // If the buffer epoch hasn't changed since the client lost
155 // connection, then the client's buffer can be synchronized with
156 // the server's buffer.
157 if buffer.epoch as u64 != client_buffer.epoch {
158 log::info!("can't rejoin buffer, epoch has changed");
159 continue;
160 }
161
162 // Find the collaborator record for this user's previous lost
163 // connection. Update it with the new connection id.
164 let server_id = ServerId(connection_id.owner_id as i32);
165 let Some(self_collaborator) = collaborators.iter_mut().find(|c| {
166 c.user_id == user_id
167 && (c.connection_lost || c.connection_server_id != server_id)
168 }) else {
169 log::info!("can't rejoin buffer, no previous collaborator found");
170 continue;
171 };
172 let old_connection_id = self_collaborator.connection();
173 *self_collaborator = channel_buffer_collaborator::ActiveModel {
174 id: ActiveValue::Unchanged(self_collaborator.id),
175 connection_id: ActiveValue::Set(connection_id.id as i32),
176 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
177 connection_lost: ActiveValue::Set(false),
178 ..Default::default()
179 }
180 .update(&*tx)
181 .await?;
182
183 let client_version = version_from_wire(&client_buffer.version);
184 let serialization_version = self
185 .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
186 .await?;
187
188 let mut rows = buffer_operation::Entity::find()
189 .filter(
190 buffer_operation::Column::BufferId
191 .eq(buffer.id)
192 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
193 )
194 .stream(&*tx)
195 .await?;
196
197 // Find the server's version vector and any operations
198 // that the client has not seen.
199 let mut server_version = clock::Global::new();
200 let mut operations = Vec::new();
201 while let Some(row) = rows.next().await {
202 let row = row?;
203 let timestamp = clock::Lamport {
204 replica_id: row.replica_id as u16,
205 value: row.lamport_timestamp as u32,
206 };
207 server_version.observe(timestamp);
208 if !client_version.observed(timestamp) {
209 operations.push(proto::Operation {
210 variant: Some(operation_from_storage(row, serialization_version)?),
211 })
212 }
213 }
214
215 results.push(RejoinedChannelBuffer {
216 old_connection_id,
217 buffer: proto::RejoinedChannelBuffer {
218 channel_id: client_buffer.channel_id,
219 version: version_to_wire(&server_version),
220 operations,
221 collaborators: collaborators
222 .into_iter()
223 .map(|collaborator| proto::Collaborator {
224 peer_id: Some(collaborator.connection().into()),
225 user_id: collaborator.user_id.to_proto(),
226 replica_id: collaborator.replica_id.0 as u32,
227 })
228 .collect(),
229 },
230 });
231 }
232
233 Ok(results)
234 })
235 .await
236 }
237
238 /// Clear out any buffer collaborators who are no longer collaborating.
239 pub async fn clear_stale_channel_buffer_collaborators(
240 &self,
241 channel_id: ChannelId,
242 server_id: ServerId,
243 ) -> Result<RefreshedChannelBuffer> {
244 self.transaction(|tx| async move {
245 let db_collaborators = channel_buffer_collaborator::Entity::find()
246 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
247 .all(&*tx)
248 .await?;
249
250 let mut connection_ids = Vec::new();
251 let mut collaborators = Vec::new();
252 let mut collaborator_ids_to_remove = Vec::new();
253 for db_collaborator in &db_collaborators {
254 if !db_collaborator.connection_lost
255 && db_collaborator.connection_server_id == server_id
256 {
257 connection_ids.push(db_collaborator.connection());
258 collaborators.push(proto::Collaborator {
259 peer_id: Some(db_collaborator.connection().into()),
260 replica_id: db_collaborator.replica_id.0 as u32,
261 user_id: db_collaborator.user_id.to_proto(),
262 })
263 } else {
264 collaborator_ids_to_remove.push(db_collaborator.id);
265 }
266 }
267
268 channel_buffer_collaborator::Entity::delete_many()
269 .filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove))
270 .exec(&*tx)
271 .await?;
272
273 Ok(RefreshedChannelBuffer {
274 connection_ids,
275 collaborators,
276 })
277 })
278 .await
279 }
280
281 /// Close the channel buffer, and stop receiving updates for it.
282 pub async fn leave_channel_buffer(
283 &self,
284 channel_id: ChannelId,
285 connection: ConnectionId,
286 ) -> Result<LeftChannelBuffer> {
287 self.transaction(|tx| async move {
288 self.leave_channel_buffer_internal(channel_id, connection, &*tx)
289 .await
290 })
291 .await
292 }
293
294 /// Close the channel buffer, and stop receiving updates for it.
295 pub async fn channel_buffer_connection_lost(
296 &self,
297 connection: ConnectionId,
298 tx: &DatabaseTransaction,
299 ) -> Result<()> {
300 channel_buffer_collaborator::Entity::update_many()
301 .filter(
302 Condition::all()
303 .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
304 .add(
305 channel_buffer_collaborator::Column::ConnectionServerId
306 .eq(connection.owner_id as i32),
307 ),
308 )
309 .set(channel_buffer_collaborator::ActiveModel {
310 connection_lost: ActiveValue::set(true),
311 ..Default::default()
312 })
313 .exec(&*tx)
314 .await?;
315 Ok(())
316 }
317
318 /// Close all open channel buffers
319 pub async fn leave_channel_buffers(
320 &self,
321 connection: ConnectionId,
322 ) -> Result<Vec<LeftChannelBuffer>> {
323 self.transaction(|tx| async move {
324 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
325 enum QueryChannelIds {
326 ChannelId,
327 }
328
329 let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
330 .select_only()
331 .column(channel_buffer_collaborator::Column::ChannelId)
332 .filter(Condition::all().add(
333 channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
334 ))
335 .into_values::<_, QueryChannelIds>()
336 .all(&*tx)
337 .await?;
338
339 let mut result = Vec::new();
340 for channel_id in channel_ids {
341 let left_channel_buffer = self
342 .leave_channel_buffer_internal(channel_id, connection, &*tx)
343 .await?;
344 result.push(left_channel_buffer);
345 }
346
347 Ok(result)
348 })
349 .await
350 }
351
352 async fn leave_channel_buffer_internal(
353 &self,
354 channel_id: ChannelId,
355 connection: ConnectionId,
356 tx: &DatabaseTransaction,
357 ) -> Result<LeftChannelBuffer> {
358 let result = channel_buffer_collaborator::Entity::delete_many()
359 .filter(
360 Condition::all()
361 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
362 .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
363 .add(
364 channel_buffer_collaborator::Column::ConnectionServerId
365 .eq(connection.owner_id as i32),
366 ),
367 )
368 .exec(&*tx)
369 .await?;
370 if result.rows_affected == 0 {
371 Err(anyhow!("not a collaborator on this project"))?;
372 }
373
374 let mut collaborators = Vec::new();
375 let mut connections = Vec::new();
376 let mut rows = channel_buffer_collaborator::Entity::find()
377 .filter(
378 Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
379 )
380 .stream(&*tx)
381 .await?;
382 while let Some(row) = rows.next().await {
383 let row = row?;
384 let connection = row.connection();
385 connections.push(connection);
386 collaborators.push(proto::Collaborator {
387 peer_id: Some(connection.into()),
388 replica_id: row.replica_id.0 as u32,
389 user_id: row.user_id.to_proto(),
390 });
391 }
392
393 drop(rows);
394
395 if collaborators.is_empty() {
396 self.snapshot_channel_buffer(channel_id, &tx).await?;
397 }
398
399 Ok(LeftChannelBuffer {
400 channel_id,
401 collaborators,
402 connections,
403 })
404 }
405
406 pub async fn get_channel_buffer_collaborators(
407 &self,
408 channel_id: ChannelId,
409 ) -> Result<Vec<UserId>> {
410 self.transaction(|tx| async move {
411 self.get_channel_buffer_collaborators_internal(channel_id, &*tx)
412 .await
413 })
414 .await
415 }
416
417 async fn get_channel_buffer_collaborators_internal(
418 &self,
419 channel_id: ChannelId,
420 tx: &DatabaseTransaction,
421 ) -> Result<Vec<UserId>> {
422 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
423 enum QueryUserIds {
424 UserId,
425 }
426
427 let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
428 .select_only()
429 .column(channel_buffer_collaborator::Column::UserId)
430 .filter(
431 Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
432 )
433 .into_values::<_, QueryUserIds>()
434 .all(&*tx)
435 .await?;
436
437 Ok(users)
438 }
439
440 pub async fn update_channel_buffer(
441 &self,
442 channel_id: ChannelId,
443 user: UserId,
444 operations: &[proto::Operation],
445 ) -> Result<(
446 Vec<ConnectionId>,
447 Vec<UserId>,
448 i32,
449 Vec<proto::VectorClockEntry>,
450 )> {
451 self.transaction(move |tx| async move {
452 let channel = self.get_channel_internal(channel_id, &*tx).await?;
453
454 let mut requires_write_permission = false;
455 for op in operations.iter() {
456 match op.variant {
457 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
458 Some(_) => requires_write_permission = true,
459 }
460 }
461 if requires_write_permission {
462 self.check_user_is_channel_member(&channel, user, &*tx)
463 .await?;
464 } else {
465 self.check_user_is_channel_participant(&channel, user, &*tx)
466 .await?;
467 }
468
469 let buffer = buffer::Entity::find()
470 .filter(buffer::Column::ChannelId.eq(channel_id))
471 .one(&*tx)
472 .await?
473 .ok_or_else(|| anyhow!("no such buffer"))?;
474
475 let serialization_version = self
476 .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
477 .await?;
478
479 let operations = operations
480 .iter()
481 .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
482 .collect::<Vec<_>>();
483
484 let mut channel_members;
485 let max_version;
486
487 if !operations.is_empty() {
488 let max_operation = operations
489 .iter()
490 .max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
491 .unwrap();
492
493 max_version = vec![proto::VectorClockEntry {
494 replica_id: *max_operation.replica_id.as_ref() as u32,
495 timestamp: *max_operation.lamport_timestamp.as_ref() as u32,
496 }];
497
498 // get current channel participants and save the max operation above
499 self.save_max_operation(
500 user,
501 buffer.id,
502 buffer.epoch,
503 *max_operation.replica_id.as_ref(),
504 *max_operation.lamport_timestamp.as_ref(),
505 &*tx,
506 )
507 .await?;
508
509 channel_members = self.get_channel_participants(&channel, &*tx).await?;
510 let collaborators = self
511 .get_channel_buffer_collaborators_internal(channel_id, &*tx)
512 .await?;
513 channel_members.retain(|member| !collaborators.contains(member));
514
515 buffer_operation::Entity::insert_many(operations)
516 .on_conflict(
517 OnConflict::columns([
518 buffer_operation::Column::BufferId,
519 buffer_operation::Column::Epoch,
520 buffer_operation::Column::LamportTimestamp,
521 buffer_operation::Column::ReplicaId,
522 ])
523 .do_nothing()
524 .to_owned(),
525 )
526 .exec(&*tx)
527 .await?;
528 } else {
529 channel_members = Vec::new();
530 max_version = Vec::new();
531 }
532
533 let mut connections = Vec::new();
534 let mut rows = channel_buffer_collaborator::Entity::find()
535 .filter(
536 Condition::all()
537 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
538 )
539 .stream(&*tx)
540 .await?;
541 while let Some(row) = rows.next().await {
542 let row = row?;
543 connections.push(ConnectionId {
544 id: row.connection_id as u32,
545 owner_id: row.connection_server_id.0 as u32,
546 });
547 }
548
549 Ok((connections, channel_members, buffer.epoch, max_version))
550 })
551 .await
552 }
553
554 async fn save_max_operation(
555 &self,
556 user_id: UserId,
557 buffer_id: BufferId,
558 epoch: i32,
559 replica_id: i32,
560 lamport_timestamp: i32,
561 tx: &DatabaseTransaction,
562 ) -> Result<()> {
563 use observed_buffer_edits::Column;
564 observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
565 user_id: ActiveValue::Set(user_id),
566 buffer_id: ActiveValue::Set(buffer_id),
567 epoch: ActiveValue::Set(epoch),
568 replica_id: ActiveValue::Set(replica_id),
569 lamport_timestamp: ActiveValue::Set(lamport_timestamp),
570 })
571 .on_conflict(
572 OnConflict::columns([Column::UserId, Column::BufferId])
573 .update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId])
574 .action_cond_where(
575 Condition::any().add(Column::Epoch.lt(epoch)).add(
576 Condition::all().add(Column::Epoch.eq(epoch)).add(
577 Condition::any()
578 .add(Column::LamportTimestamp.lt(lamport_timestamp))
579 .add(
580 Column::LamportTimestamp
581 .eq(lamport_timestamp)
582 .and(Column::ReplicaId.lt(replica_id)),
583 ),
584 ),
585 ),
586 )
587 .to_owned(),
588 )
589 .exec_without_returning(tx)
590 .await?;
591
592 Ok(())
593 }
594
595 async fn get_buffer_operation_serialization_version(
596 &self,
597 buffer_id: BufferId,
598 epoch: i32,
599 tx: &DatabaseTransaction,
600 ) -> Result<i32> {
601 Ok(buffer_snapshot::Entity::find()
602 .filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
603 .filter(buffer_snapshot::Column::Epoch.eq(epoch))
604 .select_only()
605 .column(buffer_snapshot::Column::OperationSerializationVersion)
606 .into_values::<_, QueryOperationSerializationVersion>()
607 .one(&*tx)
608 .await?
609 .ok_or_else(|| anyhow!("missing buffer snapshot"))?)
610 }
611
612 pub async fn get_channel_buffer(
613 &self,
614 channel_id: ChannelId,
615 tx: &DatabaseTransaction,
616 ) -> Result<buffer::Model> {
617 Ok(channel::Model {
618 id: channel_id,
619 ..Default::default()
620 }
621 .find_related(buffer::Entity)
622 .one(&*tx)
623 .await?
624 .ok_or_else(|| anyhow!("no such buffer"))?)
625 }
626
627 async fn get_buffer_state(
628 &self,
629 buffer: &buffer::Model,
630 tx: &DatabaseTransaction,
631 ) -> Result<(
632 String,
633 Vec<proto::Operation>,
634 Option<buffer_operation::Model>,
635 )> {
636 let id = buffer.id;
637 let (base_text, version) = if buffer.epoch > 0 {
638 let snapshot = buffer_snapshot::Entity::find()
639 .filter(
640 buffer_snapshot::Column::BufferId
641 .eq(id)
642 .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
643 )
644 .one(&*tx)
645 .await?
646 .ok_or_else(|| anyhow!("no such snapshot"))?;
647
648 let version = snapshot.operation_serialization_version;
649 (snapshot.text, version)
650 } else {
651 (String::new(), storage::SERIALIZATION_VERSION)
652 };
653
654 let mut rows = buffer_operation::Entity::find()
655 .filter(
656 buffer_operation::Column::BufferId
657 .eq(id)
658 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
659 )
660 .order_by_asc(buffer_operation::Column::LamportTimestamp)
661 .order_by_asc(buffer_operation::Column::ReplicaId)
662 .stream(&*tx)
663 .await?;
664
665 let mut operations = Vec::new();
666 let mut last_row = None;
667 while let Some(row) = rows.next().await {
668 let row = row?;
669 last_row = Some(buffer_operation::Model {
670 buffer_id: row.buffer_id,
671 epoch: row.epoch,
672 lamport_timestamp: row.lamport_timestamp,
673 replica_id: row.replica_id,
674 value: Default::default(),
675 });
676 operations.push(proto::Operation {
677 variant: Some(operation_from_storage(row, version)?),
678 });
679 }
680
681 Ok((base_text, operations, last_row))
682 }
683
684 async fn snapshot_channel_buffer(
685 &self,
686 channel_id: ChannelId,
687 tx: &DatabaseTransaction,
688 ) -> Result<()> {
689 let buffer = self.get_channel_buffer(channel_id, tx).await?;
690 let (base_text, operations, _) = self.get_buffer_state(&buffer, tx).await?;
691 if operations.is_empty() {
692 return Ok(());
693 }
694
695 let mut text_buffer = text::Buffer::new(0, text::BufferId::new(1).unwrap(), base_text);
696 text_buffer
697 .apply_ops(operations.into_iter().filter_map(operation_from_wire))
698 .unwrap();
699
700 let base_text = text_buffer.text();
701 let epoch = buffer.epoch + 1;
702
703 buffer_snapshot::Model {
704 buffer_id: buffer.id,
705 epoch,
706 text: base_text,
707 operation_serialization_version: storage::SERIALIZATION_VERSION,
708 }
709 .into_active_model()
710 .insert(tx)
711 .await?;
712
713 buffer::ActiveModel {
714 id: ActiveValue::Unchanged(buffer.id),
715 epoch: ActiveValue::Set(epoch),
716 ..Default::default()
717 }
718 .save(tx)
719 .await?;
720
721 Ok(())
722 }
723
724 pub async fn observe_buffer_version(
725 &self,
726 buffer_id: BufferId,
727 user_id: UserId,
728 epoch: i32,
729 version: &[proto::VectorClockEntry],
730 ) -> Result<()> {
731 self.transaction(|tx| async move {
732 // For now, combine concurrent operations.
733 let Some(component) = version.iter().max_by_key(|version| version.timestamp) else {
734 return Ok(());
735 };
736 self.save_max_operation(
737 user_id,
738 buffer_id,
739 epoch,
740 component.replica_id as i32,
741 component.timestamp as i32,
742 &*tx,
743 )
744 .await?;
745 Ok(())
746 })
747 .await
748 }
749
750 pub async fn latest_channel_buffer_changes(
751 &self,
752 channel_ids_by_buffer_id: &HashMap<BufferId, ChannelId>,
753 tx: &DatabaseTransaction,
754 ) -> Result<Vec<proto::ChannelBufferVersion>> {
755 let latest_operations = self
756 .get_latest_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), &*tx)
757 .await?;
758
759 Ok(latest_operations
760 .iter()
761 .flat_map(|op| {
762 Some(proto::ChannelBufferVersion {
763 channel_id: channel_ids_by_buffer_id.get(&op.buffer_id)?.to_proto(),
764 epoch: op.epoch as u64,
765 version: vec![proto::VectorClockEntry {
766 replica_id: op.replica_id as u32,
767 timestamp: op.lamport_timestamp as u32,
768 }],
769 })
770 })
771 .collect())
772 }
773
774 pub async fn observed_channel_buffer_changes(
775 &self,
776 channel_ids_by_buffer_id: &HashMap<BufferId, ChannelId>,
777 user_id: UserId,
778 tx: &DatabaseTransaction,
779 ) -> Result<Vec<proto::ChannelBufferVersion>> {
780 let observed_operations = observed_buffer_edits::Entity::find()
781 .filter(observed_buffer_edits::Column::UserId.eq(user_id))
782 .filter(
783 observed_buffer_edits::Column::BufferId
784 .is_in(channel_ids_by_buffer_id.keys().copied()),
785 )
786 .all(&*tx)
787 .await?;
788
789 Ok(observed_operations
790 .iter()
791 .flat_map(|op| {
792 Some(proto::ChannelBufferVersion {
793 channel_id: channel_ids_by_buffer_id.get(&op.buffer_id)?.to_proto(),
794 epoch: op.epoch as u64,
795 version: vec![proto::VectorClockEntry {
796 replica_id: op.replica_id as u32,
797 timestamp: op.lamport_timestamp as u32,
798 }],
799 })
800 })
801 .collect())
802 }
803
804 /// Returns the latest operations for the buffers with the specified IDs.
805 pub async fn get_latest_operations_for_buffers(
806 &self,
807 buffer_ids: impl IntoIterator<Item = BufferId>,
808 tx: &DatabaseTransaction,
809 ) -> Result<Vec<buffer_operation::Model>> {
810 let mut values = String::new();
811 for id in buffer_ids {
812 if !values.is_empty() {
813 values.push_str(", ");
814 }
815 write!(&mut values, "({})", id).unwrap();
816 }
817
818 if values.is_empty() {
819 return Ok(Vec::default());
820 }
821
822 let sql = format!(
823 r#"
824 SELECT
825 *
826 FROM
827 (
828 SELECT
829 *,
830 row_number() OVER (
831 PARTITION BY buffer_id
832 ORDER BY
833 epoch DESC,
834 lamport_timestamp DESC,
835 replica_id DESC
836 ) as row_number
837 FROM buffer_operations
838 WHERE
839 buffer_id in ({values})
840 ) AS last_operations
841 WHERE
842 row_number = 1
843 "#,
844 );
845
846 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
847 Ok(buffer_operation::Entity::find()
848 .from_raw_sql(stmt)
849 .all(&*tx)
850 .await?)
851 }
852}
853
854fn operation_to_storage(
855 operation: &proto::Operation,
856 buffer: &buffer::Model,
857 _format: i32,
858) -> Option<buffer_operation::ActiveModel> {
859 let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
860 proto::operation::Variant::Edit(operation) => (
861 operation.replica_id,
862 operation.lamport_timestamp,
863 storage::Operation {
864 version: version_to_storage(&operation.version),
865 is_undo: false,
866 edit_ranges: operation
867 .ranges
868 .iter()
869 .map(|range| storage::Range {
870 start: range.start,
871 end: range.end,
872 })
873 .collect(),
874 edit_texts: operation.new_text.clone(),
875 undo_counts: Vec::new(),
876 },
877 ),
878 proto::operation::Variant::Undo(operation) => (
879 operation.replica_id,
880 operation.lamport_timestamp,
881 storage::Operation {
882 version: version_to_storage(&operation.version),
883 is_undo: true,
884 edit_ranges: Vec::new(),
885 edit_texts: Vec::new(),
886 undo_counts: operation
887 .counts
888 .iter()
889 .map(|entry| storage::UndoCount {
890 replica_id: entry.replica_id,
891 lamport_timestamp: entry.lamport_timestamp,
892 count: entry.count,
893 })
894 .collect(),
895 },
896 ),
897 _ => None?,
898 };
899
900 Some(buffer_operation::ActiveModel {
901 buffer_id: ActiveValue::Set(buffer.id),
902 epoch: ActiveValue::Set(buffer.epoch),
903 replica_id: ActiveValue::Set(replica_id as i32),
904 lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
905 value: ActiveValue::Set(value.encode_to_vec()),
906 })
907}
908
909fn operation_from_storage(
910 row: buffer_operation::Model,
911 _format_version: i32,
912) -> Result<proto::operation::Variant, Error> {
913 let operation =
914 storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
915 let version = version_from_storage(&operation.version);
916 Ok(if operation.is_undo {
917 proto::operation::Variant::Undo(proto::operation::Undo {
918 replica_id: row.replica_id as u32,
919 lamport_timestamp: row.lamport_timestamp as u32,
920 version,
921 counts: operation
922 .undo_counts
923 .iter()
924 .map(|entry| proto::UndoCount {
925 replica_id: entry.replica_id,
926 lamport_timestamp: entry.lamport_timestamp,
927 count: entry.count,
928 })
929 .collect(),
930 })
931 } else {
932 proto::operation::Variant::Edit(proto::operation::Edit {
933 replica_id: row.replica_id as u32,
934 lamport_timestamp: row.lamport_timestamp as u32,
935 version,
936 ranges: operation
937 .edit_ranges
938 .into_iter()
939 .map(|range| proto::Range {
940 start: range.start,
941 end: range.end,
942 })
943 .collect(),
944 new_text: operation.edit_texts,
945 })
946 })
947}
948
949fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
950 version
951 .iter()
952 .map(|entry| storage::VectorClockEntry {
953 replica_id: entry.replica_id,
954 timestamp: entry.timestamp,
955 })
956 .collect()
957}
958
959fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
960 version
961 .iter()
962 .map(|entry| proto::VectorClockEntry {
963 replica_id: entry.replica_id,
964 timestamp: entry.timestamp,
965 })
966 .collect()
967}
968
969// This is currently a manual copy of the deserialization code in the client's language crate
970pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
971 match operation.variant? {
972 proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
973 timestamp: clock::Lamport {
974 replica_id: edit.replica_id as text::ReplicaId,
975 value: edit.lamport_timestamp,
976 },
977 version: version_from_wire(&edit.version),
978 ranges: edit
979 .ranges
980 .into_iter()
981 .map(|range| {
982 text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
983 })
984 .collect(),
985 new_text: edit.new_text.into_iter().map(Arc::from).collect(),
986 })),
987 proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
988 timestamp: clock::Lamport {
989 replica_id: undo.replica_id as text::ReplicaId,
990 value: undo.lamport_timestamp,
991 },
992 version: version_from_wire(&undo.version),
993 counts: undo
994 .counts
995 .into_iter()
996 .map(|c| {
997 (
998 clock::Lamport {
999 replica_id: c.replica_id as text::ReplicaId,
1000 value: c.lamport_timestamp,
1001 },
1002 c.count,
1003 )
1004 })
1005 .collect(),
1006 })),
1007 _ => None,
1008 }
1009}
1010
1011fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
1012 let mut version = clock::Global::new();
1013 for entry in message {
1014 version.observe(clock::Lamport {
1015 replica_id: entry.replica_id as text::ReplicaId,
1016 value: entry.timestamp,
1017 });
1018 }
1019 version
1020}
1021
1022fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
1023 let mut message = Vec::new();
1024 for entry in version.iter() {
1025 message.push(proto::VectorClockEntry {
1026 replica_id: entry.replica_id as u32,
1027 timestamp: entry.value,
1028 });
1029 }
1030 message
1031}
1032
1033#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
1034enum QueryOperationSerializationVersion {
1035 OperationSerializationVersion,
1036}
1037
/// Storage-side protobuf schema for persisted buffer operations.
///
/// `operation_to_storage` encodes a `storage::Operation` with prost into a
/// `buffer_operation` row's `value`, and `operation_from_storage` decodes it back.
mod storage {
    #![allow(non_snake_case)]
    use prost::Message;
    /// Serialization version written alongside new snapshots in
    /// `operation_serialization_version`.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// A persisted edit or undo. Exactly one of the edit fields or `undo_counts` is
    /// populated, depending on `is_undo`.
    ///
    /// The tag numbers are the on-disk format of existing rows: never change or reuse
    /// them. NOTE(review): tag 1 is unassigned here, presumably retired; confirm before
    /// reusing it.
    #[derive(Message)]
    pub struct Operation {
        /// The version vector the operation was based on.
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        /// `true` for an undo, `false` for an edit.
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        /// Edited ranges (edits only).
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        /// Replacement text for each edited range (edits only).
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        /// Undo counts per undone operation (undos only).
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// One (replica, timestamp) entry of a version vector.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// A `start..end` range copied from the wire `proto::Range`.
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// The undo count for the operation identified by (`replica_id`, `lamport_timestamp`).
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub lamport_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}