1use super::*;
2use prost::Message;
3use text::{EditOperation, UndoOperation};
4
5pub struct LeftChannelBuffer {
6 pub channel_id: ChannelId,
7 pub collaborators: Vec<proto::Collaborator>,
8 pub connections: Vec<ConnectionId>,
9}
10
11impl Database {
12 /// Open a channel buffer. Returns the current contents, and adds you to the list of people
13 /// to notify on changes.
14 pub async fn join_channel_buffer(
15 &self,
16 channel_id: ChannelId,
17 user_id: UserId,
18 connection: ConnectionId,
19 ) -> Result<proto::JoinChannelBufferResponse> {
20 self.transaction(|tx| async move {
21 let channel = self.get_channel_internal(channel_id, &*tx).await?;
22 self.check_user_is_channel_participant(&channel, user_id, &tx)
23 .await?;
24
25 let buffer = channel::Model {
26 id: channel_id,
27 ..Default::default()
28 }
29 .find_related(buffer::Entity)
30 .one(&*tx)
31 .await?;
32
33 let buffer = if let Some(buffer) = buffer {
34 buffer
35 } else {
36 let buffer = buffer::ActiveModel {
37 channel_id: ActiveValue::Set(channel_id),
38 ..Default::default()
39 }
40 .insert(&*tx)
41 .await?;
42 buffer_snapshot::ActiveModel {
43 buffer_id: ActiveValue::Set(buffer.id),
44 epoch: ActiveValue::Set(0),
45 text: ActiveValue::Set(String::new()),
46 operation_serialization_version: ActiveValue::Set(
47 storage::SERIALIZATION_VERSION,
48 ),
49 }
50 .insert(&*tx)
51 .await?;
52 buffer
53 };
54
55 // Join the collaborators
56 let mut collaborators = channel_buffer_collaborator::Entity::find()
57 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
58 .all(&*tx)
59 .await?;
60 let replica_ids = collaborators
61 .iter()
62 .map(|c| c.replica_id)
63 .collect::<HashSet<_>>();
64 let mut replica_id = ReplicaId(0);
65 while replica_ids.contains(&replica_id) {
66 replica_id.0 += 1;
67 }
68 let collaborator = channel_buffer_collaborator::ActiveModel {
69 channel_id: ActiveValue::Set(channel_id),
70 connection_id: ActiveValue::Set(connection.id as i32),
71 connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
72 user_id: ActiveValue::Set(user_id),
73 replica_id: ActiveValue::Set(replica_id),
74 ..Default::default()
75 }
76 .insert(&*tx)
77 .await?;
78 collaborators.push(collaborator);
79
80 let (base_text, operations, max_operation) =
81 self.get_buffer_state(&buffer, &tx).await?;
82
83 // Save the last observed operation
84 if let Some(op) = max_operation {
85 observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
86 user_id: ActiveValue::Set(user_id),
87 buffer_id: ActiveValue::Set(buffer.id),
88 epoch: ActiveValue::Set(op.epoch),
89 lamport_timestamp: ActiveValue::Set(op.lamport_timestamp),
90 replica_id: ActiveValue::Set(op.replica_id),
91 })
92 .on_conflict(
93 OnConflict::columns([
94 observed_buffer_edits::Column::UserId,
95 observed_buffer_edits::Column::BufferId,
96 ])
97 .update_columns([
98 observed_buffer_edits::Column::Epoch,
99 observed_buffer_edits::Column::LamportTimestamp,
100 ])
101 .to_owned(),
102 )
103 .exec(&*tx)
104 .await?;
105 }
106
107 Ok(proto::JoinChannelBufferResponse {
108 buffer_id: buffer.id.to_proto(),
109 replica_id: replica_id.to_proto() as u32,
110 base_text,
111 operations,
112 epoch: buffer.epoch as u64,
113 collaborators: collaborators
114 .into_iter()
115 .map(|collaborator| proto::Collaborator {
116 peer_id: Some(collaborator.connection().into()),
117 user_id: collaborator.user_id.to_proto(),
118 replica_id: collaborator.replica_id.0 as u32,
119 })
120 .collect(),
121 })
122 })
123 .await
124 }
125
126 /// Rejoin a channel buffer (after a connection interruption)
127 pub async fn rejoin_channel_buffers(
128 &self,
129 buffers: &[proto::ChannelBufferVersion],
130 user_id: UserId,
131 connection_id: ConnectionId,
132 ) -> Result<Vec<RejoinedChannelBuffer>> {
133 self.transaction(|tx| async move {
134 let mut results = Vec::new();
135 for client_buffer in buffers {
136 let channel = self
137 .get_channel_internal(ChannelId::from_proto(client_buffer.channel_id), &*tx)
138 .await?;
139 if self
140 .check_user_is_channel_participant(&channel, user_id, &*tx)
141 .await
142 .is_err()
143 {
144 log::info!("user is not a member of channel");
145 continue;
146 }
147
148 let buffer = self.get_channel_buffer(channel.id, &*tx).await?;
149 let mut collaborators = channel_buffer_collaborator::Entity::find()
150 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel.id))
151 .all(&*tx)
152 .await?;
153
154 // If the buffer epoch hasn't changed since the client lost
155 // connection, then the client's buffer can be synchronized with
156 // the server's buffer.
157 if buffer.epoch as u64 != client_buffer.epoch {
158 log::info!("can't rejoin buffer, epoch has changed");
159 continue;
160 }
161
162 // Find the collaborator record for this user's previous lost
163 // connection. Update it with the new connection id.
164 let server_id = ServerId(connection_id.owner_id as i32);
165 let Some(self_collaborator) = collaborators.iter_mut().find(|c| {
166 c.user_id == user_id
167 && (c.connection_lost || c.connection_server_id != server_id)
168 }) else {
169 log::info!("can't rejoin buffer, no previous collaborator found");
170 continue;
171 };
172 let old_connection_id = self_collaborator.connection();
173 *self_collaborator = channel_buffer_collaborator::ActiveModel {
174 id: ActiveValue::Unchanged(self_collaborator.id),
175 connection_id: ActiveValue::Set(connection_id.id as i32),
176 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
177 connection_lost: ActiveValue::Set(false),
178 ..Default::default()
179 }
180 .update(&*tx)
181 .await?;
182
183 let client_version = version_from_wire(&client_buffer.version);
184 let serialization_version = self
185 .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
186 .await?;
187
188 let mut rows = buffer_operation::Entity::find()
189 .filter(
190 buffer_operation::Column::BufferId
191 .eq(buffer.id)
192 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
193 )
194 .stream(&*tx)
195 .await?;
196
197 // Find the server's version vector and any operations
198 // that the client has not seen.
199 let mut server_version = clock::Global::new();
200 let mut operations = Vec::new();
201 while let Some(row) = rows.next().await {
202 let row = row?;
203 let timestamp = clock::Lamport {
204 replica_id: row.replica_id as u16,
205 value: row.lamport_timestamp as u32,
206 };
207 server_version.observe(timestamp);
208 if !client_version.observed(timestamp) {
209 operations.push(proto::Operation {
210 variant: Some(operation_from_storage(row, serialization_version)?),
211 })
212 }
213 }
214
215 results.push(RejoinedChannelBuffer {
216 old_connection_id,
217 buffer: proto::RejoinedChannelBuffer {
218 channel_id: client_buffer.channel_id,
219 version: version_to_wire(&server_version),
220 operations,
221 collaborators: collaborators
222 .into_iter()
223 .map(|collaborator| proto::Collaborator {
224 peer_id: Some(collaborator.connection().into()),
225 user_id: collaborator.user_id.to_proto(),
226 replica_id: collaborator.replica_id.0 as u32,
227 })
228 .collect(),
229 },
230 });
231 }
232
233 Ok(results)
234 })
235 .await
236 }
237
238 /// Clear out any buffer collaborators who are no longer collaborating.
239 pub async fn clear_stale_channel_buffer_collaborators(
240 &self,
241 channel_id: ChannelId,
242 server_id: ServerId,
243 ) -> Result<RefreshedChannelBuffer> {
244 self.transaction(|tx| async move {
245 let db_collaborators = channel_buffer_collaborator::Entity::find()
246 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
247 .all(&*tx)
248 .await?;
249
250 let mut connection_ids = Vec::new();
251 let mut collaborators = Vec::new();
252 let mut collaborator_ids_to_remove = Vec::new();
253 for db_collaborator in &db_collaborators {
254 if !db_collaborator.connection_lost
255 && db_collaborator.connection_server_id == server_id
256 {
257 connection_ids.push(db_collaborator.connection());
258 collaborators.push(proto::Collaborator {
259 peer_id: Some(db_collaborator.connection().into()),
260 replica_id: db_collaborator.replica_id.0 as u32,
261 user_id: db_collaborator.user_id.to_proto(),
262 })
263 } else {
264 collaborator_ids_to_remove.push(db_collaborator.id);
265 }
266 }
267
268 channel_buffer_collaborator::Entity::delete_many()
269 .filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove))
270 .exec(&*tx)
271 .await?;
272
273 Ok(RefreshedChannelBuffer {
274 connection_ids,
275 collaborators,
276 })
277 })
278 .await
279 }
280
281 /// Close the channel buffer, and stop receiving updates for it.
282 pub async fn leave_channel_buffer(
283 &self,
284 channel_id: ChannelId,
285 connection: ConnectionId,
286 ) -> Result<LeftChannelBuffer> {
287 self.transaction(|tx| async move {
288 self.leave_channel_buffer_internal(channel_id, connection, &*tx)
289 .await
290 })
291 .await
292 }
293
294 /// Close the channel buffer, and stop receiving updates for it.
295 pub async fn channel_buffer_connection_lost(
296 &self,
297 connection: ConnectionId,
298 tx: &DatabaseTransaction,
299 ) -> Result<()> {
300 channel_buffer_collaborator::Entity::update_many()
301 .filter(
302 Condition::all()
303 .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
304 .add(
305 channel_buffer_collaborator::Column::ConnectionServerId
306 .eq(connection.owner_id as i32),
307 ),
308 )
309 .set(channel_buffer_collaborator::ActiveModel {
310 connection_lost: ActiveValue::set(true),
311 ..Default::default()
312 })
313 .exec(&*tx)
314 .await?;
315 Ok(())
316 }
317
318 /// Close all open channel buffers
319 pub async fn leave_channel_buffers(
320 &self,
321 connection: ConnectionId,
322 ) -> Result<Vec<LeftChannelBuffer>> {
323 self.transaction(|tx| async move {
324 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
325 enum QueryChannelIds {
326 ChannelId,
327 }
328
329 let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
330 .select_only()
331 .column(channel_buffer_collaborator::Column::ChannelId)
332 .filter(Condition::all().add(
333 channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
334 ))
335 .into_values::<_, QueryChannelIds>()
336 .all(&*tx)
337 .await?;
338
339 let mut result = Vec::new();
340 for channel_id in channel_ids {
341 let left_channel_buffer = self
342 .leave_channel_buffer_internal(channel_id, connection, &*tx)
343 .await?;
344 result.push(left_channel_buffer);
345 }
346
347 Ok(result)
348 })
349 .await
350 }
351
352 async fn leave_channel_buffer_internal(
353 &self,
354 channel_id: ChannelId,
355 connection: ConnectionId,
356 tx: &DatabaseTransaction,
357 ) -> Result<LeftChannelBuffer> {
358 let result = channel_buffer_collaborator::Entity::delete_many()
359 .filter(
360 Condition::all()
361 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
362 .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
363 .add(
364 channel_buffer_collaborator::Column::ConnectionServerId
365 .eq(connection.owner_id as i32),
366 ),
367 )
368 .exec(&*tx)
369 .await?;
370 if result.rows_affected == 0 {
371 Err(anyhow!("not a collaborator on this project"))?;
372 }
373
374 let mut collaborators = Vec::new();
375 let mut connections = Vec::new();
376 let mut rows = channel_buffer_collaborator::Entity::find()
377 .filter(
378 Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
379 )
380 .stream(&*tx)
381 .await?;
382 while let Some(row) = rows.next().await {
383 let row = row?;
384 let connection = row.connection();
385 connections.push(connection);
386 collaborators.push(proto::Collaborator {
387 peer_id: Some(connection.into()),
388 replica_id: row.replica_id.0 as u32,
389 user_id: row.user_id.to_proto(),
390 });
391 }
392
393 drop(rows);
394
395 if collaborators.is_empty() {
396 self.snapshot_channel_buffer(channel_id, &tx).await?;
397 }
398
399 Ok(LeftChannelBuffer {
400 channel_id,
401 collaborators,
402 connections,
403 })
404 }
405
406 pub async fn get_channel_buffer_collaborators(
407 &self,
408 channel_id: ChannelId,
409 ) -> Result<Vec<UserId>> {
410 self.transaction(|tx| async move {
411 self.get_channel_buffer_collaborators_internal(channel_id, &*tx)
412 .await
413 })
414 .await
415 }
416
417 async fn get_channel_buffer_collaborators_internal(
418 &self,
419 channel_id: ChannelId,
420 tx: &DatabaseTransaction,
421 ) -> Result<Vec<UserId>> {
422 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
423 enum QueryUserIds {
424 UserId,
425 }
426
427 let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
428 .select_only()
429 .column(channel_buffer_collaborator::Column::UserId)
430 .filter(
431 Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
432 )
433 .into_values::<_, QueryUserIds>()
434 .all(&*tx)
435 .await?;
436
437 Ok(users)
438 }
439
440 pub async fn update_channel_buffer(
441 &self,
442 channel_id: ChannelId,
443 user: UserId,
444 operations: &[proto::Operation],
445 ) -> Result<(
446 Vec<ConnectionId>,
447 Vec<UserId>,
448 i32,
449 Vec<proto::VectorClockEntry>,
450 )> {
451 self.transaction(move |tx| async move {
452 let channel = self.get_channel_internal(channel_id, &*tx).await?;
453
454 let mut requires_write_permission = false;
455 for op in operations.iter() {
456 match op.variant {
457 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
458 Some(_) => requires_write_permission = true,
459 }
460 }
461 if requires_write_permission {
462 self.check_user_is_channel_member(&channel, user, &*tx)
463 .await?;
464 } else {
465 self.check_user_is_channel_participant(&channel, user, &*tx)
466 .await?;
467 }
468
469 let buffer = buffer::Entity::find()
470 .filter(buffer::Column::ChannelId.eq(channel_id))
471 .one(&*tx)
472 .await?
473 .ok_or_else(|| anyhow!("no such buffer"))?;
474
475 let serialization_version = self
476 .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
477 .await?;
478
479 let operations = operations
480 .iter()
481 .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
482 .collect::<Vec<_>>();
483
484 let mut channel_members;
485 let max_version;
486
487 if !operations.is_empty() {
488 let max_operation = operations
489 .iter()
490 .max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
491 .unwrap();
492
493 max_version = vec![proto::VectorClockEntry {
494 replica_id: *max_operation.replica_id.as_ref() as u32,
495 timestamp: *max_operation.lamport_timestamp.as_ref() as u32,
496 }];
497
498 // get current channel participants and save the max operation above
499 self.save_max_operation(
500 user,
501 buffer.id,
502 buffer.epoch,
503 *max_operation.replica_id.as_ref(),
504 *max_operation.lamport_timestamp.as_ref(),
505 &*tx,
506 )
507 .await?;
508
509 channel_members = self.get_channel_participants(&channel, &*tx).await?;
510 let collaborators = self
511 .get_channel_buffer_collaborators_internal(channel_id, &*tx)
512 .await?;
513 channel_members.retain(|member| !collaborators.contains(member));
514
515 buffer_operation::Entity::insert_many(operations)
516 .on_conflict(
517 OnConflict::columns([
518 buffer_operation::Column::BufferId,
519 buffer_operation::Column::Epoch,
520 buffer_operation::Column::LamportTimestamp,
521 buffer_operation::Column::ReplicaId,
522 ])
523 .do_nothing()
524 .to_owned(),
525 )
526 .exec(&*tx)
527 .await?;
528 } else {
529 channel_members = Vec::new();
530 max_version = Vec::new();
531 }
532
533 let mut connections = Vec::new();
534 let mut rows = channel_buffer_collaborator::Entity::find()
535 .filter(
536 Condition::all()
537 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
538 )
539 .stream(&*tx)
540 .await?;
541 while let Some(row) = rows.next().await {
542 let row = row?;
543 connections.push(ConnectionId {
544 id: row.connection_id as u32,
545 owner_id: row.connection_server_id.0 as u32,
546 });
547 }
548
549 Ok((connections, channel_members, buffer.epoch, max_version))
550 })
551 .await
552 }
553
554 async fn save_max_operation(
555 &self,
556 user_id: UserId,
557 buffer_id: BufferId,
558 epoch: i32,
559 replica_id: i32,
560 lamport_timestamp: i32,
561 tx: &DatabaseTransaction,
562 ) -> Result<()> {
563 use observed_buffer_edits::Column;
564
565 observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
566 user_id: ActiveValue::Set(user_id),
567 buffer_id: ActiveValue::Set(buffer_id),
568 epoch: ActiveValue::Set(epoch),
569 replica_id: ActiveValue::Set(replica_id),
570 lamport_timestamp: ActiveValue::Set(lamport_timestamp),
571 })
572 .on_conflict(
573 OnConflict::columns([Column::UserId, Column::BufferId])
574 .update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId])
575 .action_cond_where(
576 Condition::any().add(Column::Epoch.lt(epoch)).add(
577 Condition::all().add(Column::Epoch.eq(epoch)).add(
578 Condition::any()
579 .add(Column::LamportTimestamp.lt(lamport_timestamp))
580 .add(
581 Column::LamportTimestamp
582 .eq(lamport_timestamp)
583 .and(Column::ReplicaId.lt(replica_id)),
584 ),
585 ),
586 ),
587 )
588 .to_owned(),
589 )
590 .exec_without_returning(tx)
591 .await?;
592
593 Ok(())
594 }
595
596 async fn get_buffer_operation_serialization_version(
597 &self,
598 buffer_id: BufferId,
599 epoch: i32,
600 tx: &DatabaseTransaction,
601 ) -> Result<i32> {
602 Ok(buffer_snapshot::Entity::find()
603 .filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
604 .filter(buffer_snapshot::Column::Epoch.eq(epoch))
605 .select_only()
606 .column(buffer_snapshot::Column::OperationSerializationVersion)
607 .into_values::<_, QueryOperationSerializationVersion>()
608 .one(&*tx)
609 .await?
610 .ok_or_else(|| anyhow!("missing buffer snapshot"))?)
611 }
612
613 pub async fn get_channel_buffer(
614 &self,
615 channel_id: ChannelId,
616 tx: &DatabaseTransaction,
617 ) -> Result<buffer::Model> {
618 Ok(channel::Model {
619 id: channel_id,
620 ..Default::default()
621 }
622 .find_related(buffer::Entity)
623 .one(&*tx)
624 .await?
625 .ok_or_else(|| anyhow!("no such buffer"))?)
626 }
627
628 async fn get_buffer_state(
629 &self,
630 buffer: &buffer::Model,
631 tx: &DatabaseTransaction,
632 ) -> Result<(
633 String,
634 Vec<proto::Operation>,
635 Option<buffer_operation::Model>,
636 )> {
637 let id = buffer.id;
638 let (base_text, version) = if buffer.epoch > 0 {
639 let snapshot = buffer_snapshot::Entity::find()
640 .filter(
641 buffer_snapshot::Column::BufferId
642 .eq(id)
643 .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
644 )
645 .one(&*tx)
646 .await?
647 .ok_or_else(|| anyhow!("no such snapshot"))?;
648
649 let version = snapshot.operation_serialization_version;
650 (snapshot.text, version)
651 } else {
652 (String::new(), storage::SERIALIZATION_VERSION)
653 };
654
655 let mut rows = buffer_operation::Entity::find()
656 .filter(
657 buffer_operation::Column::BufferId
658 .eq(id)
659 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
660 )
661 .order_by_asc(buffer_operation::Column::LamportTimestamp)
662 .order_by_asc(buffer_operation::Column::ReplicaId)
663 .stream(&*tx)
664 .await?;
665
666 let mut operations = Vec::new();
667 let mut last_row = None;
668 while let Some(row) = rows.next().await {
669 let row = row?;
670 last_row = Some(buffer_operation::Model {
671 buffer_id: row.buffer_id,
672 epoch: row.epoch,
673 lamport_timestamp: row.lamport_timestamp,
674 replica_id: row.lamport_timestamp,
675 value: Default::default(),
676 });
677 operations.push(proto::Operation {
678 variant: Some(operation_from_storage(row, version)?),
679 });
680 }
681
682 Ok((base_text, operations, last_row))
683 }
684
685 async fn snapshot_channel_buffer(
686 &self,
687 channel_id: ChannelId,
688 tx: &DatabaseTransaction,
689 ) -> Result<()> {
690 let buffer = self.get_channel_buffer(channel_id, tx).await?;
691 let (base_text, operations, _) = self.get_buffer_state(&buffer, tx).await?;
692 if operations.is_empty() {
693 return Ok(());
694 }
695
696 let mut text_buffer = text::Buffer::new(0, 0, base_text);
697 text_buffer
698 .apply_ops(operations.into_iter().filter_map(operation_from_wire))
699 .unwrap();
700
701 let base_text = text_buffer.text();
702 let epoch = buffer.epoch + 1;
703
704 buffer_snapshot::Model {
705 buffer_id: buffer.id,
706 epoch,
707 text: base_text,
708 operation_serialization_version: storage::SERIALIZATION_VERSION,
709 }
710 .into_active_model()
711 .insert(tx)
712 .await?;
713
714 buffer::ActiveModel {
715 id: ActiveValue::Unchanged(buffer.id),
716 epoch: ActiveValue::Set(epoch),
717 ..Default::default()
718 }
719 .save(tx)
720 .await?;
721
722 Ok(())
723 }
724
725 pub async fn observe_buffer_version(
726 &self,
727 buffer_id: BufferId,
728 user_id: UserId,
729 epoch: i32,
730 version: &[proto::VectorClockEntry],
731 ) -> Result<()> {
732 self.transaction(|tx| async move {
733 // For now, combine concurrent operations.
734 let Some(component) = version.iter().max_by_key(|version| version.timestamp) else {
735 return Ok(());
736 };
737 self.save_max_operation(
738 user_id,
739 buffer_id,
740 epoch,
741 component.replica_id as i32,
742 component.timestamp as i32,
743 &*tx,
744 )
745 .await?;
746 Ok(())
747 })
748 .await
749 }
750
751 pub async fn unseen_channel_buffer_changes(
752 &self,
753 user_id: UserId,
754 channel_ids: &[ChannelId],
755 tx: &DatabaseTransaction,
756 ) -> Result<Vec<proto::UnseenChannelBufferChange>> {
757 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
758 enum QueryIds {
759 ChannelId,
760 Id,
761 }
762
763 let mut channel_ids_by_buffer_id = HashMap::default();
764 let mut rows = buffer::Entity::find()
765 .filter(buffer::Column::ChannelId.is_in(channel_ids.iter().copied()))
766 .stream(&*tx)
767 .await?;
768 while let Some(row) = rows.next().await {
769 let row = row?;
770 channel_ids_by_buffer_id.insert(row.id, row.channel_id);
771 }
772 drop(rows);
773
774 let mut observed_edits_by_buffer_id = HashMap::default();
775 let mut rows = observed_buffer_edits::Entity::find()
776 .filter(observed_buffer_edits::Column::UserId.eq(user_id))
777 .filter(
778 observed_buffer_edits::Column::BufferId
779 .is_in(channel_ids_by_buffer_id.keys().copied()),
780 )
781 .stream(&*tx)
782 .await?;
783 while let Some(row) = rows.next().await {
784 let row = row?;
785 observed_edits_by_buffer_id.insert(row.buffer_id, row);
786 }
787 drop(rows);
788
789 let latest_operations = self
790 .get_latest_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), &*tx)
791 .await?;
792
793 let mut changes = Vec::default();
794 for latest in latest_operations {
795 if let Some(observed) = observed_edits_by_buffer_id.get(&latest.buffer_id) {
796 if (
797 observed.epoch,
798 observed.lamport_timestamp,
799 observed.replica_id,
800 ) >= (latest.epoch, latest.lamport_timestamp, latest.replica_id)
801 {
802 continue;
803 }
804 }
805
806 if let Some(channel_id) = channel_ids_by_buffer_id.get(&latest.buffer_id) {
807 changes.push(proto::UnseenChannelBufferChange {
808 channel_id: channel_id.to_proto(),
809 epoch: latest.epoch as u64,
810 version: vec![proto::VectorClockEntry {
811 replica_id: latest.replica_id as u32,
812 timestamp: latest.lamport_timestamp as u32,
813 }],
814 });
815 }
816 }
817
818 Ok(changes)
819 }
820
821 /// Returns the latest operations for the buffers with the specified IDs.
822 pub async fn get_latest_operations_for_buffers(
823 &self,
824 buffer_ids: impl IntoIterator<Item = BufferId>,
825 tx: &DatabaseTransaction,
826 ) -> Result<Vec<buffer_operation::Model>> {
827 let mut values = String::new();
828 for id in buffer_ids {
829 if !values.is_empty() {
830 values.push_str(", ");
831 }
832 write!(&mut values, "({})", id).unwrap();
833 }
834
835 if values.is_empty() {
836 return Ok(Vec::default());
837 }
838
839 let sql = format!(
840 r#"
841 SELECT
842 *
843 FROM
844 (
845 SELECT
846 *,
847 row_number() OVER (
848 PARTITION BY buffer_id
849 ORDER BY
850 epoch DESC,
851 lamport_timestamp DESC,
852 replica_id DESC
853 ) as row_number
854 FROM buffer_operations
855 WHERE
856 buffer_id in ({values})
857 ) AS last_operations
858 WHERE
859 row_number = 1
860 "#,
861 );
862
863 let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
864 Ok(buffer_operation::Entity::find()
865 .from_raw_sql(stmt)
866 .all(&*tx)
867 .await?)
868 }
869}
870
871fn operation_to_storage(
872 operation: &proto::Operation,
873 buffer: &buffer::Model,
874 _format: i32,
875) -> Option<buffer_operation::ActiveModel> {
876 let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
877 proto::operation::Variant::Edit(operation) => (
878 operation.replica_id,
879 operation.lamport_timestamp,
880 storage::Operation {
881 version: version_to_storage(&operation.version),
882 is_undo: false,
883 edit_ranges: operation
884 .ranges
885 .iter()
886 .map(|range| storage::Range {
887 start: range.start,
888 end: range.end,
889 })
890 .collect(),
891 edit_texts: operation.new_text.clone(),
892 undo_counts: Vec::new(),
893 },
894 ),
895 proto::operation::Variant::Undo(operation) => (
896 operation.replica_id,
897 operation.lamport_timestamp,
898 storage::Operation {
899 version: version_to_storage(&operation.version),
900 is_undo: true,
901 edit_ranges: Vec::new(),
902 edit_texts: Vec::new(),
903 undo_counts: operation
904 .counts
905 .iter()
906 .map(|entry| storage::UndoCount {
907 replica_id: entry.replica_id,
908 lamport_timestamp: entry.lamport_timestamp,
909 count: entry.count,
910 })
911 .collect(),
912 },
913 ),
914 _ => None?,
915 };
916
917 Some(buffer_operation::ActiveModel {
918 buffer_id: ActiveValue::Set(buffer.id),
919 epoch: ActiveValue::Set(buffer.epoch),
920 replica_id: ActiveValue::Set(replica_id as i32),
921 lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
922 value: ActiveValue::Set(value.encode_to_vec()),
923 })
924}
925
926fn operation_from_storage(
927 row: buffer_operation::Model,
928 _format_version: i32,
929) -> Result<proto::operation::Variant, Error> {
930 let operation =
931 storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
932 let version = version_from_storage(&operation.version);
933 Ok(if operation.is_undo {
934 proto::operation::Variant::Undo(proto::operation::Undo {
935 replica_id: row.replica_id as u32,
936 lamport_timestamp: row.lamport_timestamp as u32,
937 version,
938 counts: operation
939 .undo_counts
940 .iter()
941 .map(|entry| proto::UndoCount {
942 replica_id: entry.replica_id,
943 lamport_timestamp: entry.lamport_timestamp,
944 count: entry.count,
945 })
946 .collect(),
947 })
948 } else {
949 proto::operation::Variant::Edit(proto::operation::Edit {
950 replica_id: row.replica_id as u32,
951 lamport_timestamp: row.lamport_timestamp as u32,
952 version,
953 ranges: operation
954 .edit_ranges
955 .into_iter()
956 .map(|range| proto::Range {
957 start: range.start,
958 end: range.end,
959 })
960 .collect(),
961 new_text: operation.edit_texts,
962 })
963 })
964}
965
966fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
967 version
968 .iter()
969 .map(|entry| storage::VectorClockEntry {
970 replica_id: entry.replica_id,
971 timestamp: entry.timestamp,
972 })
973 .collect()
974}
975
976fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
977 version
978 .iter()
979 .map(|entry| proto::VectorClockEntry {
980 replica_id: entry.replica_id,
981 timestamp: entry.timestamp,
982 })
983 .collect()
984}
985
986// This is currently a manual copy of the deserialization code in the client's language crate
987pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
988 match operation.variant? {
989 proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
990 timestamp: clock::Lamport {
991 replica_id: edit.replica_id as text::ReplicaId,
992 value: edit.lamport_timestamp,
993 },
994 version: version_from_wire(&edit.version),
995 ranges: edit
996 .ranges
997 .into_iter()
998 .map(|range| {
999 text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
1000 })
1001 .collect(),
1002 new_text: edit.new_text.into_iter().map(Arc::from).collect(),
1003 })),
1004 proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
1005 timestamp: clock::Lamport {
1006 replica_id: undo.replica_id as text::ReplicaId,
1007 value: undo.lamport_timestamp,
1008 },
1009 version: version_from_wire(&undo.version),
1010 counts: undo
1011 .counts
1012 .into_iter()
1013 .map(|c| {
1014 (
1015 clock::Lamport {
1016 replica_id: c.replica_id as text::ReplicaId,
1017 value: c.lamport_timestamp,
1018 },
1019 c.count,
1020 )
1021 })
1022 .collect(),
1023 })),
1024 _ => None,
1025 }
1026}
1027
1028fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
1029 let mut version = clock::Global::new();
1030 for entry in message {
1031 version.observe(clock::Lamport {
1032 replica_id: entry.replica_id as text::ReplicaId,
1033 value: entry.timestamp,
1034 });
1035 }
1036 version
1037}
1038
1039fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
1040 let mut message = Vec::new();
1041 for entry in version.iter() {
1042 message.push(proto::VectorClockEntry {
1043 replica_id: entry.replica_id as u32,
1044 timestamp: entry.value,
1045 });
1046 }
1047 message
1048}
1049
1050#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
1051enum QueryOperationSerializationVersion {
1052 OperationSerializationVersion,
1053}
1054
1055mod storage {
1056 #![allow(non_snake_case)]
1057 use prost::Message;
1058 pub const SERIALIZATION_VERSION: i32 = 1;
1059
1060 #[derive(Message)]
1061 pub struct Operation {
1062 #[prost(message, repeated, tag = "2")]
1063 pub version: Vec<VectorClockEntry>,
1064 #[prost(bool, tag = "3")]
1065 pub is_undo: bool,
1066 #[prost(message, repeated, tag = "4")]
1067 pub edit_ranges: Vec<Range>,
1068 #[prost(string, repeated, tag = "5")]
1069 pub edit_texts: Vec<String>,
1070 #[prost(message, repeated, tag = "6")]
1071 pub undo_counts: Vec<UndoCount>,
1072 }
1073
1074 #[derive(Message)]
1075 pub struct VectorClockEntry {
1076 #[prost(uint32, tag = "1")]
1077 pub replica_id: u32,
1078 #[prost(uint32, tag = "2")]
1079 pub timestamp: u32,
1080 }
1081
1082 #[derive(Message)]
1083 pub struct Range {
1084 #[prost(uint64, tag = "1")]
1085 pub start: u64,
1086 #[prost(uint64, tag = "2")]
1087 pub end: u64,
1088 }
1089
1090 #[derive(Message)]
1091 pub struct UndoCount {
1092 #[prost(uint32, tag = "1")]
1093 pub replica_id: u32,
1094 #[prost(uint32, tag = "2")]
1095 pub lamport_timestamp: u32,
1096 #[prost(uint32, tag = "3")]
1097 pub count: u32,
1098 }
1099}