1use super::*;
2use prost::Message;
3use text::{EditOperation, InsertionTimestamp, UndoOperation};
4
5impl Database {
6 pub async fn join_channel_buffer(
7 &self,
8 channel_id: ChannelId,
9 user_id: UserId,
10 connection: ConnectionId,
11 ) -> Result<proto::JoinChannelBufferResponse> {
12 self.transaction(|tx| async move {
13 let tx = tx;
14
15 self.check_user_is_channel_member(channel_id, user_id, &tx)
16 .await?;
17
18 let buffer = channel::Model {
19 id: channel_id,
20 ..Default::default()
21 }
22 .find_related(buffer::Entity)
23 .one(&*tx)
24 .await?;
25
26 let buffer = if let Some(buffer) = buffer {
27 buffer
28 } else {
29 let buffer = buffer::ActiveModel {
30 channel_id: ActiveValue::Set(channel_id),
31 ..Default::default()
32 }
33 .insert(&*tx)
34 .await?;
35 buffer_snapshot::ActiveModel {
36 buffer_id: ActiveValue::Set(buffer.id),
37 epoch: ActiveValue::Set(0),
38 text: ActiveValue::Set(String::new()),
39 operation_serialization_version: ActiveValue::Set(
40 storage::SERIALIZATION_VERSION,
41 ),
42 }
43 .insert(&*tx)
44 .await?;
45 buffer
46 };
47
48 // Join the collaborators
49 let mut collaborators = channel_buffer_collaborator::Entity::find()
50 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
51 .all(&*tx)
52 .await?;
53 let replica_ids = collaborators
54 .iter()
55 .map(|c| c.replica_id)
56 .collect::<HashSet<_>>();
57 let mut replica_id = ReplicaId(0);
58 while replica_ids.contains(&replica_id) {
59 replica_id.0 += 1;
60 }
61 let collaborator = channel_buffer_collaborator::ActiveModel {
62 channel_id: ActiveValue::Set(channel_id),
63 connection_id: ActiveValue::Set(connection.id as i32),
64 connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
65 user_id: ActiveValue::Set(user_id),
66 replica_id: ActiveValue::Set(replica_id),
67 ..Default::default()
68 }
69 .insert(&*tx)
70 .await?;
71 collaborators.push(collaborator);
72
73 // Assemble the buffer state
74 let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
75
76 Ok(proto::JoinChannelBufferResponse {
77 buffer_id: buffer.id.to_proto(),
78 replica_id: replica_id.to_proto() as u32,
79 base_text,
80 operations,
81 collaborators: collaborators
82 .into_iter()
83 .map(|collaborator| proto::Collaborator {
84 peer_id: Some(collaborator.connection().into()),
85 user_id: collaborator.user_id.to_proto(),
86 replica_id: collaborator.replica_id.0 as u32,
87 })
88 .collect(),
89 })
90 })
91 .await
92 }
93
94 pub async fn leave_channel_buffer(
95 &self,
96 channel_id: ChannelId,
97 connection: ConnectionId,
98 ) -> Result<Vec<ConnectionId>> {
99 self.transaction(|tx| async move {
100 self.leave_channel_buffer_internal(channel_id, connection, &*tx)
101 .await
102 })
103 .await
104 }
105
106 pub async fn leave_channel_buffer_internal(
107 &self,
108 channel_id: ChannelId,
109 connection: ConnectionId,
110 tx: &DatabaseTransaction,
111 ) -> Result<Vec<ConnectionId>> {
112 let result = channel_buffer_collaborator::Entity::delete_many()
113 .filter(
114 Condition::all()
115 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
116 .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
117 .add(
118 channel_buffer_collaborator::Column::ConnectionServerId
119 .eq(connection.owner_id as i32),
120 ),
121 )
122 .exec(&*tx)
123 .await?;
124 if result.rows_affected == 0 {
125 Err(anyhow!("not a collaborator on this project"))?;
126 }
127
128 let mut connections = Vec::new();
129 let mut rows = channel_buffer_collaborator::Entity::find()
130 .filter(
131 Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
132 )
133 .stream(&*tx)
134 .await?;
135 while let Some(row) = rows.next().await {
136 let row = row?;
137 connections.push(ConnectionId {
138 id: row.connection_id as u32,
139 owner_id: row.connection_server_id.0 as u32,
140 });
141 }
142
143 drop(rows);
144
145 if connections.is_empty() {
146 self.snapshot_buffer(channel_id, &tx).await?;
147 }
148
149 Ok(connections)
150 }
151
152 pub async fn leave_channel_buffers(
153 &self,
154 connection: ConnectionId,
155 ) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
156 self.transaction(|tx| async move {
157 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
158 enum QueryChannelIds {
159 ChannelId,
160 }
161
162 let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
163 .select_only()
164 .column(channel_buffer_collaborator::Column::ChannelId)
165 .filter(Condition::all().add(
166 channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
167 ))
168 .into_values::<_, QueryChannelIds>()
169 .all(&*tx)
170 .await?;
171
172 let mut result = Vec::new();
173 for channel_id in channel_ids {
174 let collaborators = self
175 .leave_channel_buffer_internal(channel_id, connection, &*tx)
176 .await?;
177 result.push((channel_id, collaborators));
178 }
179
180 Ok(result)
181 })
182 .await
183 }
184
185 #[cfg(debug_assertions)]
186 pub async fn get_channel_buffer_collaborators(
187 &self,
188 channel_id: ChannelId,
189 ) -> Result<Vec<UserId>> {
190 self.transaction(|tx| async move {
191 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
192 enum QueryUserIds {
193 UserId,
194 }
195
196 let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
197 .select_only()
198 .column(channel_buffer_collaborator::Column::UserId)
199 .filter(
200 Condition::all()
201 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
202 )
203 .into_values::<_, QueryUserIds>()
204 .all(&*tx)
205 .await?;
206
207 Ok(users)
208 })
209 .await
210 }
211
212 pub async fn update_channel_buffer(
213 &self,
214 channel_id: ChannelId,
215 user: UserId,
216 operations: &[proto::Operation],
217 ) -> Result<Vec<ConnectionId>> {
218 self.transaction(move |tx| async move {
219 self.check_user_is_channel_member(channel_id, user, &*tx)
220 .await?;
221
222 let buffer = buffer::Entity::find()
223 .filter(buffer::Column::ChannelId.eq(channel_id))
224 .one(&*tx)
225 .await?
226 .ok_or_else(|| anyhow!("no such buffer"))?;
227
228 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
229 enum QueryVersion {
230 OperationSerializationVersion,
231 }
232
233 let serialization_version: i32 = buffer
234 .find_related(buffer_snapshot::Entity)
235 .select_only()
236 .column(buffer_snapshot::Column::OperationSerializationVersion)
237 .filter(buffer_snapshot::Column::Epoch.eq(buffer.epoch))
238 .into_values::<_, QueryVersion>()
239 .one(&*tx)
240 .await?
241 .ok_or_else(|| anyhow!("missing buffer snapshot"))?;
242
243 let operations = operations
244 .iter()
245 .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
246 .collect::<Vec<_>>();
247 if !operations.is_empty() {
248 buffer_operation::Entity::insert_many(operations)
249 .exec(&*tx)
250 .await?;
251 }
252
253 let mut connections = Vec::new();
254 let mut rows = channel_buffer_collaborator::Entity::find()
255 .filter(
256 Condition::all()
257 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
258 )
259 .stream(&*tx)
260 .await?;
261 while let Some(row) = rows.next().await {
262 let row = row?;
263 connections.push(ConnectionId {
264 id: row.connection_id as u32,
265 owner_id: row.connection_server_id.0 as u32,
266 });
267 }
268
269 Ok(connections)
270 })
271 .await
272 }
273
274 async fn get_buffer_state(
275 &self,
276 buffer: &buffer::Model,
277 tx: &DatabaseTransaction,
278 ) -> Result<(String, Vec<proto::Operation>)> {
279 let id = buffer.id;
280 let (base_text, version) = if buffer.epoch > 0 {
281 let snapshot = buffer_snapshot::Entity::find()
282 .filter(
283 buffer_snapshot::Column::BufferId
284 .eq(id)
285 .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
286 )
287 .one(&*tx)
288 .await?
289 .ok_or_else(|| anyhow!("no such snapshot"))?;
290
291 let version = snapshot.operation_serialization_version;
292 (snapshot.text, version)
293 } else {
294 (String::new(), storage::SERIALIZATION_VERSION)
295 };
296
297 let mut rows = buffer_operation::Entity::find()
298 .filter(
299 buffer_operation::Column::BufferId
300 .eq(id)
301 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
302 )
303 .stream(&*tx)
304 .await?;
305 let mut operations = Vec::new();
306 while let Some(row) = rows.next().await {
307 let row = row?;
308
309 let operation = operation_from_storage(row, version)?;
310 operations.push(proto::Operation {
311 variant: Some(operation),
312 })
313 }
314
315 Ok((base_text, operations))
316 }
317
318 async fn snapshot_buffer(&self, channel_id: ChannelId, tx: &DatabaseTransaction) -> Result<()> {
319 let buffer = channel::Model {
320 id: channel_id,
321 ..Default::default()
322 }
323 .find_related(buffer::Entity)
324 .one(&*tx)
325 .await?
326 .ok_or_else(|| anyhow!("no such buffer"))?;
327
328 let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
329 if operations.is_empty() {
330 return Ok(());
331 }
332
333 let mut text_buffer = text::Buffer::new(0, 0, base_text);
334 text_buffer
335 .apply_ops(operations.into_iter().filter_map(operation_from_wire))
336 .unwrap();
337
338 let base_text = text_buffer.text();
339 let epoch = buffer.epoch + 1;
340
341 buffer_snapshot::Model {
342 buffer_id: buffer.id,
343 epoch,
344 text: base_text,
345 operation_serialization_version: storage::SERIALIZATION_VERSION,
346 }
347 .into_active_model()
348 .insert(tx)
349 .await?;
350
351 buffer::ActiveModel {
352 id: ActiveValue::Unchanged(buffer.id),
353 epoch: ActiveValue::Set(epoch),
354 ..Default::default()
355 }
356 .save(tx)
357 .await?;
358
359 Ok(())
360 }
361}
362
363fn operation_to_storage(
364 operation: &proto::Operation,
365 buffer: &buffer::Model,
366 _format: i32,
367) -> Option<buffer_operation::ActiveModel> {
368 let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
369 proto::operation::Variant::Edit(operation) => (
370 operation.replica_id,
371 operation.lamport_timestamp,
372 storage::Operation {
373 local_timestamp: operation.local_timestamp,
374 version: version_to_storage(&operation.version),
375 is_undo: false,
376 edit_ranges: operation
377 .ranges
378 .iter()
379 .map(|range| storage::Range {
380 start: range.start,
381 end: range.end,
382 })
383 .collect(),
384 edit_texts: operation.new_text.clone(),
385 undo_counts: Vec::new(),
386 },
387 ),
388 proto::operation::Variant::Undo(operation) => (
389 operation.replica_id,
390 operation.lamport_timestamp,
391 storage::Operation {
392 local_timestamp: operation.local_timestamp,
393 version: version_to_storage(&operation.version),
394 is_undo: true,
395 edit_ranges: Vec::new(),
396 edit_texts: Vec::new(),
397 undo_counts: operation
398 .counts
399 .iter()
400 .map(|entry| storage::UndoCount {
401 replica_id: entry.replica_id,
402 local_timestamp: entry.local_timestamp,
403 count: entry.count,
404 })
405 .collect(),
406 },
407 ),
408 _ => None?,
409 };
410
411 Some(buffer_operation::ActiveModel {
412 buffer_id: ActiveValue::Set(buffer.id),
413 epoch: ActiveValue::Set(buffer.epoch),
414 replica_id: ActiveValue::Set(replica_id as i32),
415 lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
416 value: ActiveValue::Set(value.encode_to_vec()),
417 })
418}
419
420fn operation_from_storage(
421 row: buffer_operation::Model,
422 _format_version: i32,
423) -> Result<proto::operation::Variant, Error> {
424 let operation =
425 storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
426 let version = version_from_storage(&operation.version);
427 Ok(if operation.is_undo {
428 proto::operation::Variant::Undo(proto::operation::Undo {
429 replica_id: row.replica_id as u32,
430 local_timestamp: operation.local_timestamp as u32,
431 lamport_timestamp: row.lamport_timestamp as u32,
432 version,
433 counts: operation
434 .undo_counts
435 .iter()
436 .map(|entry| proto::UndoCount {
437 replica_id: entry.replica_id,
438 local_timestamp: entry.local_timestamp,
439 count: entry.count,
440 })
441 .collect(),
442 })
443 } else {
444 proto::operation::Variant::Edit(proto::operation::Edit {
445 replica_id: row.replica_id as u32,
446 local_timestamp: operation.local_timestamp as u32,
447 lamport_timestamp: row.lamport_timestamp as u32,
448 version,
449 ranges: operation
450 .edit_ranges
451 .into_iter()
452 .map(|range| proto::Range {
453 start: range.start,
454 end: range.end,
455 })
456 .collect(),
457 new_text: operation.edit_texts,
458 })
459 })
460}
461
462fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
463 version
464 .iter()
465 .map(|entry| storage::VectorClockEntry {
466 replica_id: entry.replica_id,
467 timestamp: entry.timestamp,
468 })
469 .collect()
470}
471
472fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
473 version
474 .iter()
475 .map(|entry| proto::VectorClockEntry {
476 replica_id: entry.replica_id,
477 timestamp: entry.timestamp,
478 })
479 .collect()
480}
481
// This is currently a manual copy of the deserialization code in the client's language crate
483pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
484 match operation.variant? {
485 proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
486 timestamp: InsertionTimestamp {
487 replica_id: edit.replica_id as text::ReplicaId,
488 local: edit.local_timestamp,
489 lamport: edit.lamport_timestamp,
490 },
491 version: version_from_wire(&edit.version),
492 ranges: edit
493 .ranges
494 .into_iter()
495 .map(|range| {
496 text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
497 })
498 .collect(),
499 new_text: edit.new_text.into_iter().map(Arc::from).collect(),
500 })),
501 proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo {
502 lamport_timestamp: clock::Lamport {
503 replica_id: undo.replica_id as text::ReplicaId,
504 value: undo.lamport_timestamp,
505 },
506 undo: UndoOperation {
507 id: clock::Local {
508 replica_id: undo.replica_id as text::ReplicaId,
509 value: undo.local_timestamp,
510 },
511 version: version_from_wire(&undo.version),
512 counts: undo
513 .counts
514 .into_iter()
515 .map(|c| {
516 (
517 clock::Local {
518 replica_id: c.replica_id as text::ReplicaId,
519 value: c.local_timestamp,
520 },
521 c.count,
522 )
523 })
524 .collect(),
525 },
526 }),
527 _ => None,
528 }
529}
530
531fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
532 let mut version = clock::Global::new();
533 for entry in message {
534 version.observe(clock::Local {
535 replica_id: entry.replica_id as text::ReplicaId,
536 value: entry.timestamp,
537 });
538 }
539 version
540}
541
/// Protobuf schema for buffer operations persisted in the database.
///
/// `Operation` is encoded into the `value` blob of `buffer_operation` rows. The row's own
/// columns hold the replica id and Lamport timestamp, so those are not part of this
/// message. The prost tags below define the stored format and must not be changed or
/// reused for different fields.
mod storage {
    // NOTE(review): no reason is given for this lint allowance; presumably it silences
    // warnings from the prost derive output. Confirm before removing.
    #![allow(non_snake_case)]
    use prost::Message;
    /// Serialization version recorded in `buffer_snapshot.operation_serialization_version`
    /// when snapshots are written.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// A stored edit or undo operation (selected by `is_undo`).
    #[derive(Message)]
    pub struct Operation {
        /// The operation's local timestamp.
        #[prost(uint32, tag = "1")]
        pub local_timestamp: u32,
        /// Vector clock of the operation.
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        /// `true` for an undo, `false` for an edit.
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        /// Edited ranges (populated for edits, empty for undos).
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        /// Replacement texts for `edit_ranges` (populated for edits, empty for undos).
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        /// Undo counts (populated for undos, empty for edits).
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// One `(replica_id, timestamp)` entry of a vector clock.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// A `start..end` offset range of an edit.
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// The undo count for the operation identified by `(replica_id, local_timestamp)`.
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub local_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}