1use super::*;
2use prost::Message;
3use text::{EditOperation, UndoOperation};
4
5impl Database {
6 pub async fn join_channel_buffer(
7 &self,
8 channel_id: ChannelId,
9 user_id: UserId,
10 connection: ConnectionId,
11 ) -> Result<proto::JoinChannelBufferResponse> {
12 self.transaction(|tx| async move {
13 self.check_user_is_channel_member(channel_id, user_id, &tx)
14 .await?;
15
16 let buffer = channel::Model {
17 id: channel_id,
18 ..Default::default()
19 }
20 .find_related(buffer::Entity)
21 .one(&*tx)
22 .await?;
23
24 let buffer = if let Some(buffer) = buffer {
25 buffer
26 } else {
27 let buffer = buffer::ActiveModel {
28 channel_id: ActiveValue::Set(channel_id),
29 ..Default::default()
30 }
31 .insert(&*tx)
32 .await?;
33 buffer_snapshot::ActiveModel {
34 buffer_id: ActiveValue::Set(buffer.id),
35 epoch: ActiveValue::Set(0),
36 text: ActiveValue::Set(String::new()),
37 operation_serialization_version: ActiveValue::Set(
38 storage::SERIALIZATION_VERSION,
39 ),
40 }
41 .insert(&*tx)
42 .await?;
43 buffer
44 };
45
46 // Join the collaborators
47 let mut collaborators = channel_buffer_collaborator::Entity::find()
48 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
49 .all(&*tx)
50 .await?;
51 let replica_ids = collaborators
52 .iter()
53 .map(|c| c.replica_id)
54 .collect::<HashSet<_>>();
55 let mut replica_id = ReplicaId(0);
56 while replica_ids.contains(&replica_id) {
57 replica_id.0 += 1;
58 }
59 let collaborator = channel_buffer_collaborator::ActiveModel {
60 channel_id: ActiveValue::Set(channel_id),
61 connection_id: ActiveValue::Set(connection.id as i32),
62 connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
63 user_id: ActiveValue::Set(user_id),
64 replica_id: ActiveValue::Set(replica_id),
65 ..Default::default()
66 }
67 .insert(&*tx)
68 .await?;
69 collaborators.push(collaborator);
70
71 let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
72
73 Ok(proto::JoinChannelBufferResponse {
74 buffer_id: buffer.id.to_proto(),
75 replica_id: replica_id.to_proto() as u32,
76 base_text,
77 operations,
78 epoch: buffer.epoch as u64,
79 collaborators: collaborators
80 .into_iter()
81 .map(|collaborator| proto::Collaborator {
82 peer_id: Some(collaborator.connection().into()),
83 user_id: collaborator.user_id.to_proto(),
84 replica_id: collaborator.replica_id.0 as u32,
85 })
86 .collect(),
87 })
88 })
89 .await
90 }
91
92 pub async fn rejoin_channel_buffers(
93 &self,
94 buffers: &[proto::ChannelBufferVersion],
95 user_id: UserId,
96 connection_id: ConnectionId,
97 ) -> Result<proto::RejoinChannelBuffersResponse> {
98 self.transaction(|tx| async move {
99 let mut response = proto::RejoinChannelBuffersResponse::default();
100 for client_buffer in buffers {
101 let channel_id = ChannelId::from_proto(client_buffer.channel_id);
102 if self
103 .check_user_is_channel_member(channel_id, user_id, &*tx)
104 .await
105 .is_err()
106 {
107 log::info!("user is not a member of channel");
108 continue;
109 }
110
111 let buffer = self.get_channel_buffer(channel_id, &*tx).await?;
112 let mut collaborators = channel_buffer_collaborator::Entity::find()
113 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
114 .all(&*tx)
115 .await?;
116
117 // If the buffer epoch hasn't changed since the client lost
118 // connection, then the client's buffer can be syncronized with
119 // the server's buffer.
120 if buffer.epoch as u64 != client_buffer.epoch {
121 continue;
122 }
123
124 // If there is still a disconnected collaborator for the user,
125 // update the connection associated with that collaborator, and reuse
126 // that replica id.
127 if let Some(ix) = collaborators
128 .iter()
129 .position(|c| c.user_id == user_id && c.connection_lost)
130 {
131 let self_collaborator = &mut collaborators[ix];
132 *self_collaborator = channel_buffer_collaborator::ActiveModel {
133 id: ActiveValue::Unchanged(self_collaborator.id),
134 connection_id: ActiveValue::Set(connection_id.id as i32),
135 connection_server_id: ActiveValue::Set(ServerId(
136 connection_id.owner_id as i32,
137 )),
138 connection_lost: ActiveValue::Set(false),
139 ..Default::default()
140 }
141 .update(&*tx)
142 .await?;
143 } else {
144 continue;
145 }
146
147 let client_version = version_from_wire(&client_buffer.version);
148 let serialization_version = self
149 .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
150 .await?;
151
152 let mut rows = buffer_operation::Entity::find()
153 .filter(
154 buffer_operation::Column::BufferId
155 .eq(buffer.id)
156 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
157 )
158 .stream(&*tx)
159 .await?;
160
161 // Find the server's version vector and any operations
162 // that the client has not seen.
163 let mut server_version = clock::Global::new();
164 let mut operations = Vec::new();
165 while let Some(row) = rows.next().await {
166 let row = row?;
167 let timestamp = clock::Lamport {
168 replica_id: row.replica_id as u16,
169 value: row.lamport_timestamp as u32,
170 };
171 server_version.observe(timestamp);
172 if !client_version.observed(timestamp) {
173 operations.push(proto::Operation {
174 variant: Some(operation_from_storage(row, serialization_version)?),
175 })
176 }
177 }
178
179 response.buffers.push(proto::RejoinedChannelBuffer {
180 channel_id: client_buffer.channel_id,
181 version: version_to_wire(&server_version),
182 operations,
183 collaborators: collaborators
184 .into_iter()
185 .map(|collaborator| proto::Collaborator {
186 peer_id: Some(collaborator.connection().into()),
187 user_id: collaborator.user_id.to_proto(),
188 replica_id: collaborator.replica_id.0 as u32,
189 })
190 .collect(),
191 });
192 }
193
194 Ok(response)
195 })
196 .await
197 }
198
199 pub async fn leave_channel_buffer(
200 &self,
201 channel_id: ChannelId,
202 connection: ConnectionId,
203 ) -> Result<Vec<ConnectionId>> {
204 self.transaction(|tx| async move {
205 self.leave_channel_buffer_internal(channel_id, connection, &*tx)
206 .await
207 })
208 .await
209 }
210
211 pub async fn leave_channel_buffers(
212 &self,
213 connection: ConnectionId,
214 ) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
215 self.transaction(|tx| async move {
216 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
217 enum QueryChannelIds {
218 ChannelId,
219 }
220
221 let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
222 .select_only()
223 .column(channel_buffer_collaborator::Column::ChannelId)
224 .filter(Condition::all().add(
225 channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
226 ))
227 .into_values::<_, QueryChannelIds>()
228 .all(&*tx)
229 .await?;
230
231 let mut result = Vec::new();
232 for channel_id in channel_ids {
233 let collaborators = self
234 .leave_channel_buffer_internal(channel_id, connection, &*tx)
235 .await?;
236 result.push((channel_id, collaborators));
237 }
238
239 Ok(result)
240 })
241 .await
242 }
243
244 pub async fn leave_channel_buffer_internal(
245 &self,
246 channel_id: ChannelId,
247 connection: ConnectionId,
248 tx: &DatabaseTransaction,
249 ) -> Result<Vec<ConnectionId>> {
250 let result = channel_buffer_collaborator::Entity::delete_many()
251 .filter(
252 Condition::all()
253 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
254 .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
255 .add(
256 channel_buffer_collaborator::Column::ConnectionServerId
257 .eq(connection.owner_id as i32),
258 ),
259 )
260 .exec(&*tx)
261 .await?;
262 if result.rows_affected == 0 {
263 Err(anyhow!("not a collaborator on this project"))?;
264 }
265
266 let mut connections = Vec::new();
267 let mut rows = channel_buffer_collaborator::Entity::find()
268 .filter(
269 Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
270 )
271 .stream(&*tx)
272 .await?;
273 while let Some(row) = rows.next().await {
274 let row = row?;
275 connections.push(ConnectionId {
276 id: row.connection_id as u32,
277 owner_id: row.connection_server_id.0 as u32,
278 });
279 }
280
281 drop(rows);
282
283 if connections.is_empty() {
284 self.snapshot_channel_buffer(channel_id, &tx).await?;
285 }
286
287 Ok(connections)
288 }
289
290 pub async fn get_channel_buffer_collaborators(
291 &self,
292 channel_id: ChannelId,
293 ) -> Result<Vec<UserId>> {
294 self.transaction(|tx| async move {
295 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
296 enum QueryUserIds {
297 UserId,
298 }
299
300 let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
301 .select_only()
302 .column(channel_buffer_collaborator::Column::UserId)
303 .filter(
304 Condition::all()
305 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
306 )
307 .into_values::<_, QueryUserIds>()
308 .all(&*tx)
309 .await?;
310
311 Ok(users)
312 })
313 .await
314 }
315
316 pub async fn update_channel_buffer(
317 &self,
318 channel_id: ChannelId,
319 user: UserId,
320 operations: &[proto::Operation],
321 ) -> Result<Vec<ConnectionId>> {
322 self.transaction(move |tx| async move {
323 self.check_user_is_channel_member(channel_id, user, &*tx)
324 .await?;
325
326 let buffer = buffer::Entity::find()
327 .filter(buffer::Column::ChannelId.eq(channel_id))
328 .one(&*tx)
329 .await?
330 .ok_or_else(|| anyhow!("no such buffer"))?;
331
332 let serialization_version = self
333 .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
334 .await?;
335
336 let operations = operations
337 .iter()
338 .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
339 .collect::<Vec<_>>();
340 if !operations.is_empty() {
341 buffer_operation::Entity::insert_many(operations)
342 .exec(&*tx)
343 .await?;
344 }
345
346 let mut connections = Vec::new();
347 let mut rows = channel_buffer_collaborator::Entity::find()
348 .filter(
349 Condition::all()
350 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
351 )
352 .stream(&*tx)
353 .await?;
354 while let Some(row) = rows.next().await {
355 let row = row?;
356 connections.push(ConnectionId {
357 id: row.connection_id as u32,
358 owner_id: row.connection_server_id.0 as u32,
359 });
360 }
361
362 Ok(connections)
363 })
364 .await
365 }
366
367 async fn get_buffer_operation_serialization_version(
368 &self,
369 buffer_id: BufferId,
370 epoch: i32,
371 tx: &DatabaseTransaction,
372 ) -> Result<i32> {
373 Ok(buffer_snapshot::Entity::find()
374 .filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
375 .filter(buffer_snapshot::Column::Epoch.eq(epoch))
376 .select_only()
377 .column(buffer_snapshot::Column::OperationSerializationVersion)
378 .into_values::<_, QueryOperationSerializationVersion>()
379 .one(&*tx)
380 .await?
381 .ok_or_else(|| anyhow!("missing buffer snapshot"))?)
382 }
383
384 async fn get_channel_buffer(
385 &self,
386 channel_id: ChannelId,
387 tx: &DatabaseTransaction,
388 ) -> Result<buffer::Model> {
389 Ok(channel::Model {
390 id: channel_id,
391 ..Default::default()
392 }
393 .find_related(buffer::Entity)
394 .one(&*tx)
395 .await?
396 .ok_or_else(|| anyhow!("no such buffer"))?)
397 }
398
399 async fn get_buffer_state(
400 &self,
401 buffer: &buffer::Model,
402 tx: &DatabaseTransaction,
403 ) -> Result<(String, Vec<proto::Operation>)> {
404 let id = buffer.id;
405 let (base_text, version) = if buffer.epoch > 0 {
406 let snapshot = buffer_snapshot::Entity::find()
407 .filter(
408 buffer_snapshot::Column::BufferId
409 .eq(id)
410 .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
411 )
412 .one(&*tx)
413 .await?
414 .ok_or_else(|| anyhow!("no such snapshot"))?;
415
416 let version = snapshot.operation_serialization_version;
417 (snapshot.text, version)
418 } else {
419 (String::new(), storage::SERIALIZATION_VERSION)
420 };
421
422 let mut rows = buffer_operation::Entity::find()
423 .filter(
424 buffer_operation::Column::BufferId
425 .eq(id)
426 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
427 )
428 .stream(&*tx)
429 .await?;
430 let mut operations = Vec::new();
431 while let Some(row) = rows.next().await {
432 operations.push(proto::Operation {
433 variant: Some(operation_from_storage(row?, version)?),
434 })
435 }
436
437 Ok((base_text, operations))
438 }
439
440 async fn snapshot_channel_buffer(
441 &self,
442 channel_id: ChannelId,
443 tx: &DatabaseTransaction,
444 ) -> Result<()> {
445 let buffer = self.get_channel_buffer(channel_id, tx).await?;
446 let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
447 if operations.is_empty() {
448 return Ok(());
449 }
450
451 let mut text_buffer = text::Buffer::new(0, 0, base_text);
452 text_buffer
453 .apply_ops(operations.into_iter().filter_map(operation_from_wire))
454 .unwrap();
455
456 let base_text = text_buffer.text();
457 let epoch = buffer.epoch + 1;
458
459 buffer_snapshot::Model {
460 buffer_id: buffer.id,
461 epoch,
462 text: base_text,
463 operation_serialization_version: storage::SERIALIZATION_VERSION,
464 }
465 .into_active_model()
466 .insert(tx)
467 .await?;
468
469 buffer::ActiveModel {
470 id: ActiveValue::Unchanged(buffer.id),
471 epoch: ActiveValue::Set(epoch),
472 ..Default::default()
473 }
474 .save(tx)
475 .await?;
476
477 Ok(())
478 }
479}
480
481fn operation_to_storage(
482 operation: &proto::Operation,
483 buffer: &buffer::Model,
484 _format: i32,
485) -> Option<buffer_operation::ActiveModel> {
486 let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
487 proto::operation::Variant::Edit(operation) => (
488 operation.replica_id,
489 operation.lamport_timestamp,
490 storage::Operation {
491 version: version_to_storage(&operation.version),
492 is_undo: false,
493 edit_ranges: operation
494 .ranges
495 .iter()
496 .map(|range| storage::Range {
497 start: range.start,
498 end: range.end,
499 })
500 .collect(),
501 edit_texts: operation.new_text.clone(),
502 undo_counts: Vec::new(),
503 },
504 ),
505 proto::operation::Variant::Undo(operation) => (
506 operation.replica_id,
507 operation.lamport_timestamp,
508 storage::Operation {
509 version: version_to_storage(&operation.version),
510 is_undo: true,
511 edit_ranges: Vec::new(),
512 edit_texts: Vec::new(),
513 undo_counts: operation
514 .counts
515 .iter()
516 .map(|entry| storage::UndoCount {
517 replica_id: entry.replica_id,
518 lamport_timestamp: entry.lamport_timestamp,
519 count: entry.count,
520 })
521 .collect(),
522 },
523 ),
524 _ => None?,
525 };
526
527 Some(buffer_operation::ActiveModel {
528 buffer_id: ActiveValue::Set(buffer.id),
529 epoch: ActiveValue::Set(buffer.epoch),
530 replica_id: ActiveValue::Set(replica_id as i32),
531 lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
532 value: ActiveValue::Set(value.encode_to_vec()),
533 })
534}
535
536fn operation_from_storage(
537 row: buffer_operation::Model,
538 _format_version: i32,
539) -> Result<proto::operation::Variant, Error> {
540 let operation =
541 storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
542 let version = version_from_storage(&operation.version);
543 Ok(if operation.is_undo {
544 proto::operation::Variant::Undo(proto::operation::Undo {
545 replica_id: row.replica_id as u32,
546 lamport_timestamp: row.lamport_timestamp as u32,
547 version,
548 counts: operation
549 .undo_counts
550 .iter()
551 .map(|entry| proto::UndoCount {
552 replica_id: entry.replica_id,
553 lamport_timestamp: entry.lamport_timestamp,
554 count: entry.count,
555 })
556 .collect(),
557 })
558 } else {
559 proto::operation::Variant::Edit(proto::operation::Edit {
560 replica_id: row.replica_id as u32,
561 lamport_timestamp: row.lamport_timestamp as u32,
562 version,
563 ranges: operation
564 .edit_ranges
565 .into_iter()
566 .map(|range| proto::Range {
567 start: range.start,
568 end: range.end,
569 })
570 .collect(),
571 new_text: operation.edit_texts,
572 })
573 })
574}
575
576fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
577 version
578 .iter()
579 .map(|entry| storage::VectorClockEntry {
580 replica_id: entry.replica_id,
581 timestamp: entry.timestamp,
582 })
583 .collect()
584}
585
586fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
587 version
588 .iter()
589 .map(|entry| proto::VectorClockEntry {
590 replica_id: entry.replica_id,
591 timestamp: entry.timestamp,
592 })
593 .collect()
594}
595
596// This is currently a manual copy of the deserialization code in the client's langauge crate
597pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
598 match operation.variant? {
599 proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
600 timestamp: clock::Lamport {
601 replica_id: edit.replica_id as text::ReplicaId,
602 value: edit.lamport_timestamp,
603 },
604 version: version_from_wire(&edit.version),
605 ranges: edit
606 .ranges
607 .into_iter()
608 .map(|range| {
609 text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
610 })
611 .collect(),
612 new_text: edit.new_text.into_iter().map(Arc::from).collect(),
613 })),
614 proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
615 timestamp: clock::Lamport {
616 replica_id: undo.replica_id as text::ReplicaId,
617 value: undo.lamport_timestamp,
618 },
619 version: version_from_wire(&undo.version),
620 counts: undo
621 .counts
622 .into_iter()
623 .map(|c| {
624 (
625 clock::Lamport {
626 replica_id: c.replica_id as text::ReplicaId,
627 value: c.lamport_timestamp,
628 },
629 c.count,
630 )
631 })
632 .collect(),
633 })),
634 _ => None,
635 }
636}
637
638fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
639 let mut version = clock::Global::new();
640 for entry in message {
641 version.observe(clock::Lamport {
642 replica_id: entry.replica_id as text::ReplicaId,
643 value: entry.timestamp,
644 });
645 }
646 version
647}
648
649fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
650 let mut message = Vec::new();
651 for entry in version.iter() {
652 message.push(proto::VectorClockEntry {
653 replica_id: entry.replica_id as u32,
654 timestamp: entry.value,
655 });
656 }
657 message
658}
659
// Single-column selector for sea-orm's `into_values`. It lets
// `get_buffer_operation_serialization_version` read only the
// `operation_serialization_version` column of a `buffer_snapshot` row.
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryOperationSerializationVersion {
    OperationSerializationVersion,
}
664
/// Protobuf messages used to encode buffer operations in the
/// `buffer_operation.value` column.
///
/// These mirror the parts of the wire `proto` operation types that are
/// persisted. `operation_to_storage` and `operation_from_storage` convert
/// between the two. Field tags are part of the stored format and must not be
/// changed for existing fields.
mod storage {
    // NOTE(review): nothing in this module is visibly non-snake-case;
    // presumably kept for derive-generated code. Confirm before removing.
    #![allow(non_snake_case)]
    use prost::Message;
    /// Serialization version written to `buffer_snapshot.operation_serialization_version`
    /// for snapshots created by this module's callers.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// A stored edit or undo operation. The replica id and Lamport timestamp
    /// live in their own `buffer_operation` columns, not in this message.
    ///
    /// Tag 1 is not used by this message; presumably it is reserved, so don't
    /// reuse it without confirming.
    #[derive(Message)]
    pub struct Operation {
        /// Vector clock copied from the wire operation's `version`.
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        /// `true` for undo operations, `false` for edits.
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        /// Ranges of an edit (empty for undos).
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        /// New text of an edit, copied from the wire `new_text` (empty for undos).
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        /// Undo counts of an undo (empty for edits).
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// One entry of a stored vector clock.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// An edit range, with endpoints copied from the wire `proto::Range`.
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// An undo-count entry, copied from the wire `proto::UndoCount`.
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub lamport_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}