1use super::*;
2use prost::Message;
3use text::{EditOperation, UndoOperation};
4
5impl Database {
6 pub async fn join_channel_buffer(
7 &self,
8 channel_id: ChannelId,
9 user_id: UserId,
10 connection: ConnectionId,
11 ) -> Result<proto::JoinChannelBufferResponse> {
12 self.transaction(|tx| async move {
13 self.check_user_is_channel_member(channel_id, user_id, &tx)
14 .await?;
15
16 let buffer = channel::Model {
17 id: channel_id,
18 ..Default::default()
19 }
20 .find_related(buffer::Entity)
21 .one(&*tx)
22 .await?;
23
24 let buffer = if let Some(buffer) = buffer {
25 buffer
26 } else {
27 let buffer = buffer::ActiveModel {
28 channel_id: ActiveValue::Set(channel_id),
29 ..Default::default()
30 }
31 .insert(&*tx)
32 .await?;
33 buffer_snapshot::ActiveModel {
34 buffer_id: ActiveValue::Set(buffer.id),
35 epoch: ActiveValue::Set(0),
36 text: ActiveValue::Set(String::new()),
37 operation_serialization_version: ActiveValue::Set(
38 storage::SERIALIZATION_VERSION,
39 ),
40 }
41 .insert(&*tx)
42 .await?;
43 buffer
44 };
45
46 // Join the collaborators
47 let mut collaborators = channel_buffer_collaborator::Entity::find()
48 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
49 .all(&*tx)
50 .await?;
51 let replica_ids = collaborators
52 .iter()
53 .map(|c| c.replica_id)
54 .collect::<HashSet<_>>();
55 let mut replica_id = ReplicaId(0);
56 while replica_ids.contains(&replica_id) {
57 replica_id.0 += 1;
58 }
59 let collaborator = channel_buffer_collaborator::ActiveModel {
60 channel_id: ActiveValue::Set(channel_id),
61 connection_id: ActiveValue::Set(connection.id as i32),
62 connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
63 user_id: ActiveValue::Set(user_id),
64 replica_id: ActiveValue::Set(replica_id),
65 ..Default::default()
66 }
67 .insert(&*tx)
68 .await?;
69 collaborators.push(collaborator);
70
71 let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
72
73 Ok(proto::JoinChannelBufferResponse {
74 buffer_id: buffer.id.to_proto(),
75 replica_id: replica_id.to_proto() as u32,
76 base_text,
77 operations,
78 epoch: buffer.epoch as u64,
79 collaborators: collaborators
80 .into_iter()
81 .map(|collaborator| proto::Collaborator {
82 peer_id: Some(collaborator.connection().into()),
83 user_id: collaborator.user_id.to_proto(),
84 replica_id: collaborator.replica_id.0 as u32,
85 })
86 .collect(),
87 })
88 })
89 .await
90 }
91
92 pub async fn rejoin_channel_buffers(
93 &self,
94 buffers: &[proto::ChannelBufferVersion],
95 user_id: UserId,
96 connection_id: ConnectionId,
97 ) -> Result<Vec<RejoinedChannelBuffer>> {
98 self.transaction(|tx| async move {
99 let mut results = Vec::new();
100 for client_buffer in buffers {
101 let channel_id = ChannelId::from_proto(client_buffer.channel_id);
102 if self
103 .check_user_is_channel_member(channel_id, user_id, &*tx)
104 .await
105 .is_err()
106 {
107 log::info!("user is not a member of channel");
108 continue;
109 }
110
111 let buffer = self.get_channel_buffer(channel_id, &*tx).await?;
112 let mut collaborators = channel_buffer_collaborator::Entity::find()
113 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
114 .all(&*tx)
115 .await?;
116
117 // If the buffer epoch hasn't changed since the client lost
118 // connection, then the client's buffer can be syncronized with
119 // the server's buffer.
120 if buffer.epoch as u64 != client_buffer.epoch {
121 continue;
122 }
123
124 // Find the collaborator record for this user's previous lost
125 // connection. Update it with the new connection id.
126 let server_id = ServerId(connection_id.owner_id as i32);
127 let Some(self_collaborator) = collaborators.iter_mut().find(|c| {
128 c.user_id == user_id
129 && (c.connection_lost || c.connection_server_id != server_id)
130 }) else {
131 continue;
132 };
133 let old_connection_id = self_collaborator.connection();
134 *self_collaborator = channel_buffer_collaborator::ActiveModel {
135 id: ActiveValue::Unchanged(self_collaborator.id),
136 connection_id: ActiveValue::Set(connection_id.id as i32),
137 connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
138 connection_lost: ActiveValue::Set(false),
139 ..Default::default()
140 }
141 .update(&*tx)
142 .await?;
143
144 let client_version = version_from_wire(&client_buffer.version);
145 let serialization_version = self
146 .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
147 .await?;
148
149 let mut rows = buffer_operation::Entity::find()
150 .filter(
151 buffer_operation::Column::BufferId
152 .eq(buffer.id)
153 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
154 )
155 .stream(&*tx)
156 .await?;
157
158 // Find the server's version vector and any operations
159 // that the client has not seen.
160 let mut server_version = clock::Global::new();
161 let mut operations = Vec::new();
162 while let Some(row) = rows.next().await {
163 let row = row?;
164 let timestamp = clock::Lamport {
165 replica_id: row.replica_id as u16,
166 value: row.lamport_timestamp as u32,
167 };
168 server_version.observe(timestamp);
169 if !client_version.observed(timestamp) {
170 operations.push(proto::Operation {
171 variant: Some(operation_from_storage(row, serialization_version)?),
172 })
173 }
174 }
175
176 results.push(RejoinedChannelBuffer {
177 old_connection_id,
178 buffer: proto::RejoinedChannelBuffer {
179 channel_id: client_buffer.channel_id,
180 version: version_to_wire(&server_version),
181 operations,
182 collaborators: collaborators
183 .into_iter()
184 .map(|collaborator| proto::Collaborator {
185 peer_id: Some(collaborator.connection().into()),
186 user_id: collaborator.user_id.to_proto(),
187 replica_id: collaborator.replica_id.0 as u32,
188 })
189 .collect(),
190 },
191 });
192 }
193
194 Ok(results)
195 })
196 .await
197 }
198
199 pub async fn refresh_channel_buffer(
200 &self,
201 channel_id: ChannelId,
202 server_id: ServerId,
203 ) -> Result<RefreshedChannelBuffer> {
204 self.transaction(|tx| async move {
205 let mut connection_ids = Vec::new();
206 let mut removed_collaborators = Vec::new();
207
208 // TODO
209
210 Ok(RefreshedChannelBuffer {
211 connection_ids,
212 removed_collaborators,
213 })
214 })
215 .await
216 }
217
218 pub async fn leave_channel_buffer(
219 &self,
220 channel_id: ChannelId,
221 connection: ConnectionId,
222 ) -> Result<Vec<ConnectionId>> {
223 self.transaction(|tx| async move {
224 self.leave_channel_buffer_internal(channel_id, connection, &*tx)
225 .await
226 })
227 .await
228 }
229
230 pub async fn leave_channel_buffers(
231 &self,
232 connection: ConnectionId,
233 ) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
234 self.transaction(|tx| async move {
235 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
236 enum QueryChannelIds {
237 ChannelId,
238 }
239
240 let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
241 .select_only()
242 .column(channel_buffer_collaborator::Column::ChannelId)
243 .filter(Condition::all().add(
244 channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
245 ))
246 .into_values::<_, QueryChannelIds>()
247 .all(&*tx)
248 .await?;
249
250 let mut result = Vec::new();
251 for channel_id in channel_ids {
252 let collaborators = self
253 .leave_channel_buffer_internal(channel_id, connection, &*tx)
254 .await?;
255 result.push((channel_id, collaborators));
256 }
257
258 Ok(result)
259 })
260 .await
261 }
262
263 pub async fn leave_channel_buffer_internal(
264 &self,
265 channel_id: ChannelId,
266 connection: ConnectionId,
267 tx: &DatabaseTransaction,
268 ) -> Result<Vec<ConnectionId>> {
269 let result = channel_buffer_collaborator::Entity::delete_many()
270 .filter(
271 Condition::all()
272 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
273 .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
274 .add(
275 channel_buffer_collaborator::Column::ConnectionServerId
276 .eq(connection.owner_id as i32),
277 ),
278 )
279 .exec(&*tx)
280 .await?;
281 if result.rows_affected == 0 {
282 Err(anyhow!("not a collaborator on this project"))?;
283 }
284
285 let mut connections = Vec::new();
286 let mut rows = channel_buffer_collaborator::Entity::find()
287 .filter(
288 Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
289 )
290 .stream(&*tx)
291 .await?;
292 while let Some(row) = rows.next().await {
293 let row = row?;
294 connections.push(ConnectionId {
295 id: row.connection_id as u32,
296 owner_id: row.connection_server_id.0 as u32,
297 });
298 }
299
300 drop(rows);
301
302 if connections.is_empty() {
303 self.snapshot_channel_buffer(channel_id, &tx).await?;
304 }
305
306 Ok(connections)
307 }
308
309 pub async fn get_channel_buffer_collaborators(
310 &self,
311 channel_id: ChannelId,
312 ) -> Result<Vec<UserId>> {
313 self.transaction(|tx| async move {
314 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
315 enum QueryUserIds {
316 UserId,
317 }
318
319 let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
320 .select_only()
321 .column(channel_buffer_collaborator::Column::UserId)
322 .filter(
323 Condition::all()
324 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
325 )
326 .into_values::<_, QueryUserIds>()
327 .all(&*tx)
328 .await?;
329
330 Ok(users)
331 })
332 .await
333 }
334
335 pub async fn update_channel_buffer(
336 &self,
337 channel_id: ChannelId,
338 user: UserId,
339 operations: &[proto::Operation],
340 ) -> Result<Vec<ConnectionId>> {
341 self.transaction(move |tx| async move {
342 self.check_user_is_channel_member(channel_id, user, &*tx)
343 .await?;
344
345 let buffer = buffer::Entity::find()
346 .filter(buffer::Column::ChannelId.eq(channel_id))
347 .one(&*tx)
348 .await?
349 .ok_or_else(|| anyhow!("no such buffer"))?;
350
351 let serialization_version = self
352 .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
353 .await?;
354
355 let operations = operations
356 .iter()
357 .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
358 .collect::<Vec<_>>();
359 if !operations.is_empty() {
360 buffer_operation::Entity::insert_many(operations)
361 .exec(&*tx)
362 .await?;
363 }
364
365 let mut connections = Vec::new();
366 let mut rows = channel_buffer_collaborator::Entity::find()
367 .filter(
368 Condition::all()
369 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
370 )
371 .stream(&*tx)
372 .await?;
373 while let Some(row) = rows.next().await {
374 let row = row?;
375 connections.push(ConnectionId {
376 id: row.connection_id as u32,
377 owner_id: row.connection_server_id.0 as u32,
378 });
379 }
380
381 Ok(connections)
382 })
383 .await
384 }
385
386 async fn get_buffer_operation_serialization_version(
387 &self,
388 buffer_id: BufferId,
389 epoch: i32,
390 tx: &DatabaseTransaction,
391 ) -> Result<i32> {
392 Ok(buffer_snapshot::Entity::find()
393 .filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
394 .filter(buffer_snapshot::Column::Epoch.eq(epoch))
395 .select_only()
396 .column(buffer_snapshot::Column::OperationSerializationVersion)
397 .into_values::<_, QueryOperationSerializationVersion>()
398 .one(&*tx)
399 .await?
400 .ok_or_else(|| anyhow!("missing buffer snapshot"))?)
401 }
402
403 async fn get_channel_buffer(
404 &self,
405 channel_id: ChannelId,
406 tx: &DatabaseTransaction,
407 ) -> Result<buffer::Model> {
408 Ok(channel::Model {
409 id: channel_id,
410 ..Default::default()
411 }
412 .find_related(buffer::Entity)
413 .one(&*tx)
414 .await?
415 .ok_or_else(|| anyhow!("no such buffer"))?)
416 }
417
418 async fn get_buffer_state(
419 &self,
420 buffer: &buffer::Model,
421 tx: &DatabaseTransaction,
422 ) -> Result<(String, Vec<proto::Operation>)> {
423 let id = buffer.id;
424 let (base_text, version) = if buffer.epoch > 0 {
425 let snapshot = buffer_snapshot::Entity::find()
426 .filter(
427 buffer_snapshot::Column::BufferId
428 .eq(id)
429 .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
430 )
431 .one(&*tx)
432 .await?
433 .ok_or_else(|| anyhow!("no such snapshot"))?;
434
435 let version = snapshot.operation_serialization_version;
436 (snapshot.text, version)
437 } else {
438 (String::new(), storage::SERIALIZATION_VERSION)
439 };
440
441 let mut rows = buffer_operation::Entity::find()
442 .filter(
443 buffer_operation::Column::BufferId
444 .eq(id)
445 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
446 )
447 .stream(&*tx)
448 .await?;
449 let mut operations = Vec::new();
450 while let Some(row) = rows.next().await {
451 operations.push(proto::Operation {
452 variant: Some(operation_from_storage(row?, version)?),
453 })
454 }
455
456 Ok((base_text, operations))
457 }
458
459 async fn snapshot_channel_buffer(
460 &self,
461 channel_id: ChannelId,
462 tx: &DatabaseTransaction,
463 ) -> Result<()> {
464 let buffer = self.get_channel_buffer(channel_id, tx).await?;
465 let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
466 if operations.is_empty() {
467 return Ok(());
468 }
469
470 let mut text_buffer = text::Buffer::new(0, 0, base_text);
471 text_buffer
472 .apply_ops(operations.into_iter().filter_map(operation_from_wire))
473 .unwrap();
474
475 let base_text = text_buffer.text();
476 let epoch = buffer.epoch + 1;
477
478 buffer_snapshot::Model {
479 buffer_id: buffer.id,
480 epoch,
481 text: base_text,
482 operation_serialization_version: storage::SERIALIZATION_VERSION,
483 }
484 .into_active_model()
485 .insert(tx)
486 .await?;
487
488 buffer::ActiveModel {
489 id: ActiveValue::Unchanged(buffer.id),
490 epoch: ActiveValue::Set(epoch),
491 ..Default::default()
492 }
493 .save(tx)
494 .await?;
495
496 Ok(())
497 }
498}
499
500fn operation_to_storage(
501 operation: &proto::Operation,
502 buffer: &buffer::Model,
503 _format: i32,
504) -> Option<buffer_operation::ActiveModel> {
505 let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
506 proto::operation::Variant::Edit(operation) => (
507 operation.replica_id,
508 operation.lamport_timestamp,
509 storage::Operation {
510 version: version_to_storage(&operation.version),
511 is_undo: false,
512 edit_ranges: operation
513 .ranges
514 .iter()
515 .map(|range| storage::Range {
516 start: range.start,
517 end: range.end,
518 })
519 .collect(),
520 edit_texts: operation.new_text.clone(),
521 undo_counts: Vec::new(),
522 },
523 ),
524 proto::operation::Variant::Undo(operation) => (
525 operation.replica_id,
526 operation.lamport_timestamp,
527 storage::Operation {
528 version: version_to_storage(&operation.version),
529 is_undo: true,
530 edit_ranges: Vec::new(),
531 edit_texts: Vec::new(),
532 undo_counts: operation
533 .counts
534 .iter()
535 .map(|entry| storage::UndoCount {
536 replica_id: entry.replica_id,
537 lamport_timestamp: entry.lamport_timestamp,
538 count: entry.count,
539 })
540 .collect(),
541 },
542 ),
543 _ => None?,
544 };
545
546 Some(buffer_operation::ActiveModel {
547 buffer_id: ActiveValue::Set(buffer.id),
548 epoch: ActiveValue::Set(buffer.epoch),
549 replica_id: ActiveValue::Set(replica_id as i32),
550 lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
551 value: ActiveValue::Set(value.encode_to_vec()),
552 })
553}
554
555fn operation_from_storage(
556 row: buffer_operation::Model,
557 _format_version: i32,
558) -> Result<proto::operation::Variant, Error> {
559 let operation =
560 storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
561 let version = version_from_storage(&operation.version);
562 Ok(if operation.is_undo {
563 proto::operation::Variant::Undo(proto::operation::Undo {
564 replica_id: row.replica_id as u32,
565 lamport_timestamp: row.lamport_timestamp as u32,
566 version,
567 counts: operation
568 .undo_counts
569 .iter()
570 .map(|entry| proto::UndoCount {
571 replica_id: entry.replica_id,
572 lamport_timestamp: entry.lamport_timestamp,
573 count: entry.count,
574 })
575 .collect(),
576 })
577 } else {
578 proto::operation::Variant::Edit(proto::operation::Edit {
579 replica_id: row.replica_id as u32,
580 lamport_timestamp: row.lamport_timestamp as u32,
581 version,
582 ranges: operation
583 .edit_ranges
584 .into_iter()
585 .map(|range| proto::Range {
586 start: range.start,
587 end: range.end,
588 })
589 .collect(),
590 new_text: operation.edit_texts,
591 })
592 })
593}
594
595fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
596 version
597 .iter()
598 .map(|entry| storage::VectorClockEntry {
599 replica_id: entry.replica_id,
600 timestamp: entry.timestamp,
601 })
602 .collect()
603}
604
605fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
606 version
607 .iter()
608 .map(|entry| proto::VectorClockEntry {
609 replica_id: entry.replica_id,
610 timestamp: entry.timestamp,
611 })
612 .collect()
613}
614
615// This is currently a manual copy of the deserialization code in the client's langauge crate
616pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
617 match operation.variant? {
618 proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
619 timestamp: clock::Lamport {
620 replica_id: edit.replica_id as text::ReplicaId,
621 value: edit.lamport_timestamp,
622 },
623 version: version_from_wire(&edit.version),
624 ranges: edit
625 .ranges
626 .into_iter()
627 .map(|range| {
628 text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
629 })
630 .collect(),
631 new_text: edit.new_text.into_iter().map(Arc::from).collect(),
632 })),
633 proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
634 timestamp: clock::Lamport {
635 replica_id: undo.replica_id as text::ReplicaId,
636 value: undo.lamport_timestamp,
637 },
638 version: version_from_wire(&undo.version),
639 counts: undo
640 .counts
641 .into_iter()
642 .map(|c| {
643 (
644 clock::Lamport {
645 replica_id: c.replica_id as text::ReplicaId,
646 value: c.lamport_timestamp,
647 },
648 c.count,
649 )
650 })
651 .collect(),
652 })),
653 _ => None,
654 }
655}
656
657fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
658 let mut version = clock::Global::new();
659 for entry in message {
660 version.observe(clock::Lamport {
661 replica_id: entry.replica_id as text::ReplicaId,
662 value: entry.timestamp,
663 });
664 }
665 version
666}
667
668fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
669 let mut message = Vec::new();
670 for entry in version.iter() {
671 message.push(proto::VectorClockEntry {
672 replica_id: entry.replica_id as u32,
673 timestamp: entry.value,
674 });
675 }
676 message
677}
678
/// Column selector for the `into_values` query in
/// `get_buffer_operation_serialization_version`, which reads only
/// `buffer_snapshot::Column::OperationSerializationVersion`.
// NOTE(review): the variant name presumably has to match the selected column for
// `DeriveColumn`; don't rename it without checking.
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryOperationSerializationVersion {
    OperationSerializationVersion,
}
683
/// Protobuf messages defining how buffer operations are persisted.
///
/// `Operation` values are `encode_to_vec`'d into `buffer_operation.value` by
/// `operation_to_storage` and decoded again by `operation_from_storage`. The field
/// tags below are therefore part of the stored format: changing or reusing a tag
/// would break decoding of rows that already exist.
mod storage {
    // NOTE(review): no non-snake-case names are visible in this module; the allow is
    // presumably for code emitted by the `Message` derive — confirm before removing.
    #![allow(non_snake_case)]
    use prost::Message;

    /// Value written to `operation_serialization_version` on snapshots created by
    /// this code.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// A stored edit or undo operation.
    ///
    /// The replica id and lamport timestamp are not part of this message; they are
    /// kept in separate `buffer_operation` columns.
    // NOTE(review): tag 1 is unused here — presumably retired; don't reuse it
    // without checking existing stored data.
    #[derive(Message)]
    pub struct Operation {
        /// Version vector of the operation.
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        /// `true` for an undo (which uses `undo_counts`); `false` for an edit (which
        /// uses `edit_ranges` and `edit_texts`).
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        /// Ranges of an edit, copied from the wire edit's `ranges`.
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        /// Inserted texts of an edit, copied from the wire edit's `new_text`.
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        /// Undo counts of an undo, copied from the wire undo's `counts`.
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// One entry of a stored version vector.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// A stored edit range, as `start`/`end` offsets copied from the wire range.
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// A stored undo count, keyed by the replica id and lamport timestamp of the
    /// operation it refers to.
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub lamport_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}