1use super::*;
2use prost::Message;
3use text::{EditOperation, UndoOperation};
4
5impl Database {
6 pub async fn join_channel_buffer(
7 &self,
8 channel_id: ChannelId,
9 user_id: UserId,
10 connection: ConnectionId,
11 ) -> Result<proto::JoinChannelBufferResponse> {
12 self.transaction(|tx| async move {
13 let tx = tx;
14
15 self.check_user_is_channel_member(channel_id, user_id, &tx)
16 .await?;
17
18 let buffer = channel::Model {
19 id: channel_id,
20 ..Default::default()
21 }
22 .find_related(buffer::Entity)
23 .one(&*tx)
24 .await?;
25
26 let buffer = if let Some(buffer) = buffer {
27 buffer
28 } else {
29 let buffer = buffer::ActiveModel {
30 channel_id: ActiveValue::Set(channel_id),
31 ..Default::default()
32 }
33 .insert(&*tx)
34 .await?;
35 buffer_snapshot::ActiveModel {
36 buffer_id: ActiveValue::Set(buffer.id),
37 epoch: ActiveValue::Set(0),
38 text: ActiveValue::Set(String::new()),
39 operation_serialization_version: ActiveValue::Set(
40 storage::SERIALIZATION_VERSION,
41 ),
42 }
43 .insert(&*tx)
44 .await?;
45 buffer
46 };
47
48 // Join the collaborators
49 let mut collaborators = channel_buffer_collaborator::Entity::find()
50 .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
51 .all(&*tx)
52 .await?;
53 let replica_ids = collaborators
54 .iter()
55 .map(|c| c.replica_id)
56 .collect::<HashSet<_>>();
57 let mut replica_id = ReplicaId(0);
58 while replica_ids.contains(&replica_id) {
59 replica_id.0 += 1;
60 }
61 let collaborator = channel_buffer_collaborator::ActiveModel {
62 channel_id: ActiveValue::Set(channel_id),
63 connection_id: ActiveValue::Set(connection.id as i32),
64 connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
65 user_id: ActiveValue::Set(user_id),
66 replica_id: ActiveValue::Set(replica_id),
67 ..Default::default()
68 }
69 .insert(&*tx)
70 .await?;
71 collaborators.push(collaborator);
72
73 // Assemble the buffer state
74 let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
75
76 Ok(proto::JoinChannelBufferResponse {
77 buffer_id: buffer.id.to_proto(),
78 replica_id: replica_id.to_proto() as u32,
79 base_text,
80 operations,
81 collaborators: collaborators
82 .into_iter()
83 .map(|collaborator| proto::Collaborator {
84 peer_id: Some(collaborator.connection().into()),
85 user_id: collaborator.user_id.to_proto(),
86 replica_id: collaborator.replica_id.0 as u32,
87 })
88 .collect(),
89 })
90 })
91 .await
92 }
93
94 pub async fn leave_channel_buffer(
95 &self,
96 channel_id: ChannelId,
97 connection: ConnectionId,
98 ) -> Result<Vec<ConnectionId>> {
99 self.transaction(|tx| async move {
100 self.leave_channel_buffer_internal(channel_id, connection, &*tx)
101 .await
102 })
103 .await
104 }
105
106 pub async fn leave_channel_buffer_internal(
107 &self,
108 channel_id: ChannelId,
109 connection: ConnectionId,
110 tx: &DatabaseTransaction,
111 ) -> Result<Vec<ConnectionId>> {
112 let result = channel_buffer_collaborator::Entity::delete_many()
113 .filter(
114 Condition::all()
115 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
116 .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
117 .add(
118 channel_buffer_collaborator::Column::ConnectionServerId
119 .eq(connection.owner_id as i32),
120 ),
121 )
122 .exec(&*tx)
123 .await?;
124 if result.rows_affected == 0 {
125 Err(anyhow!("not a collaborator on this project"))?;
126 }
127
128 let mut connections = Vec::new();
129 let mut rows = channel_buffer_collaborator::Entity::find()
130 .filter(
131 Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
132 )
133 .stream(&*tx)
134 .await?;
135 while let Some(row) = rows.next().await {
136 let row = row?;
137 connections.push(ConnectionId {
138 id: row.connection_id as u32,
139 owner_id: row.connection_server_id.0 as u32,
140 });
141 }
142
143 drop(rows);
144
145 if connections.is_empty() {
146 self.snapshot_buffer(channel_id, &tx).await?;
147 }
148
149 Ok(connections)
150 }
151
152 pub async fn leave_channel_buffers(
153 &self,
154 connection: ConnectionId,
155 ) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
156 self.transaction(|tx| async move {
157 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
158 enum QueryChannelIds {
159 ChannelId,
160 }
161
162 let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
163 .select_only()
164 .column(channel_buffer_collaborator::Column::ChannelId)
165 .filter(Condition::all().add(
166 channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
167 ))
168 .into_values::<_, QueryChannelIds>()
169 .all(&*tx)
170 .await?;
171
172 let mut result = Vec::new();
173 for channel_id in channel_ids {
174 let collaborators = self
175 .leave_channel_buffer_internal(channel_id, connection, &*tx)
176 .await?;
177 result.push((channel_id, collaborators));
178 }
179
180 Ok(result)
181 })
182 .await
183 }
184
185 pub async fn get_channel_buffer_collaborators(
186 &self,
187 channel_id: ChannelId,
188 ) -> Result<Vec<UserId>> {
189 self.transaction(|tx| async move {
190 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
191 enum QueryUserIds {
192 UserId,
193 }
194
195 let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
196 .select_only()
197 .column(channel_buffer_collaborator::Column::UserId)
198 .filter(
199 Condition::all()
200 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
201 )
202 .into_values::<_, QueryUserIds>()
203 .all(&*tx)
204 .await?;
205
206 Ok(users)
207 })
208 .await
209 }
210
211 pub async fn update_channel_buffer(
212 &self,
213 channel_id: ChannelId,
214 user: UserId,
215 operations: &[proto::Operation],
216 ) -> Result<Vec<ConnectionId>> {
217 self.transaction(move |tx| async move {
218 self.check_user_is_channel_member(channel_id, user, &*tx)
219 .await?;
220
221 let buffer = buffer::Entity::find()
222 .filter(buffer::Column::ChannelId.eq(channel_id))
223 .one(&*tx)
224 .await?
225 .ok_or_else(|| anyhow!("no such buffer"))?;
226
227 #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
228 enum QueryVersion {
229 OperationSerializationVersion,
230 }
231
232 let serialization_version: i32 = buffer
233 .find_related(buffer_snapshot::Entity)
234 .select_only()
235 .column(buffer_snapshot::Column::OperationSerializationVersion)
236 .filter(buffer_snapshot::Column::Epoch.eq(buffer.epoch))
237 .into_values::<_, QueryVersion>()
238 .one(&*tx)
239 .await?
240 .ok_or_else(|| anyhow!("missing buffer snapshot"))?;
241
242 let operations = operations
243 .iter()
244 .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
245 .collect::<Vec<_>>();
246 if !operations.is_empty() {
247 buffer_operation::Entity::insert_many(operations)
248 .exec(&*tx)
249 .await?;
250 }
251
252 let mut connections = Vec::new();
253 let mut rows = channel_buffer_collaborator::Entity::find()
254 .filter(
255 Condition::all()
256 .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
257 )
258 .stream(&*tx)
259 .await?;
260 while let Some(row) = rows.next().await {
261 let row = row?;
262 connections.push(ConnectionId {
263 id: row.connection_id as u32,
264 owner_id: row.connection_server_id.0 as u32,
265 });
266 }
267
268 Ok(connections)
269 })
270 .await
271 }
272
273 async fn get_buffer_state(
274 &self,
275 buffer: &buffer::Model,
276 tx: &DatabaseTransaction,
277 ) -> Result<(String, Vec<proto::Operation>)> {
278 let id = buffer.id;
279 let (base_text, version) = if buffer.epoch > 0 {
280 let snapshot = buffer_snapshot::Entity::find()
281 .filter(
282 buffer_snapshot::Column::BufferId
283 .eq(id)
284 .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
285 )
286 .one(&*tx)
287 .await?
288 .ok_or_else(|| anyhow!("no such snapshot"))?;
289
290 let version = snapshot.operation_serialization_version;
291 (snapshot.text, version)
292 } else {
293 (String::new(), storage::SERIALIZATION_VERSION)
294 };
295
296 let mut rows = buffer_operation::Entity::find()
297 .filter(
298 buffer_operation::Column::BufferId
299 .eq(id)
300 .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
301 )
302 .stream(&*tx)
303 .await?;
304 let mut operations = Vec::new();
305 while let Some(row) = rows.next().await {
306 let row = row?;
307
308 let operation = operation_from_storage(row, version)?;
309 operations.push(proto::Operation {
310 variant: Some(operation),
311 })
312 }
313
314 Ok((base_text, operations))
315 }
316
317 async fn snapshot_buffer(&self, channel_id: ChannelId, tx: &DatabaseTransaction) -> Result<()> {
318 let buffer = channel::Model {
319 id: channel_id,
320 ..Default::default()
321 }
322 .find_related(buffer::Entity)
323 .one(&*tx)
324 .await?
325 .ok_or_else(|| anyhow!("no such buffer"))?;
326
327 let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
328 if operations.is_empty() {
329 return Ok(());
330 }
331
332 let mut text_buffer = text::Buffer::new(0, 0, base_text);
333 text_buffer
334 .apply_ops(operations.into_iter().filter_map(operation_from_wire))
335 .unwrap();
336
337 let base_text = text_buffer.text();
338 let epoch = buffer.epoch + 1;
339
340 buffer_snapshot::Model {
341 buffer_id: buffer.id,
342 epoch,
343 text: base_text,
344 operation_serialization_version: storage::SERIALIZATION_VERSION,
345 }
346 .into_active_model()
347 .insert(tx)
348 .await?;
349
350 buffer::ActiveModel {
351 id: ActiveValue::Unchanged(buffer.id),
352 epoch: ActiveValue::Set(epoch),
353 ..Default::default()
354 }
355 .save(tx)
356 .await?;
357
358 Ok(())
359 }
360}
361
362fn operation_to_storage(
363 operation: &proto::Operation,
364 buffer: &buffer::Model,
365 _format: i32,
366) -> Option<buffer_operation::ActiveModel> {
367 let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
368 proto::operation::Variant::Edit(operation) => (
369 operation.replica_id,
370 operation.lamport_timestamp,
371 storage::Operation {
372 version: version_to_storage(&operation.version),
373 is_undo: false,
374 edit_ranges: operation
375 .ranges
376 .iter()
377 .map(|range| storage::Range {
378 start: range.start,
379 end: range.end,
380 })
381 .collect(),
382 edit_texts: operation.new_text.clone(),
383 undo_counts: Vec::new(),
384 },
385 ),
386 proto::operation::Variant::Undo(operation) => (
387 operation.replica_id,
388 operation.lamport_timestamp,
389 storage::Operation {
390 version: version_to_storage(&operation.version),
391 is_undo: true,
392 edit_ranges: Vec::new(),
393 edit_texts: Vec::new(),
394 undo_counts: operation
395 .counts
396 .iter()
397 .map(|entry| storage::UndoCount {
398 replica_id: entry.replica_id,
399 lamport_timestamp: entry.lamport_timestamp,
400 count: entry.count,
401 })
402 .collect(),
403 },
404 ),
405 _ => None?,
406 };
407
408 Some(buffer_operation::ActiveModel {
409 buffer_id: ActiveValue::Set(buffer.id),
410 epoch: ActiveValue::Set(buffer.epoch),
411 replica_id: ActiveValue::Set(replica_id as i32),
412 lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
413 value: ActiveValue::Set(value.encode_to_vec()),
414 })
415}
416
417fn operation_from_storage(
418 row: buffer_operation::Model,
419 _format_version: i32,
420) -> Result<proto::operation::Variant, Error> {
421 let operation =
422 storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
423 let version = version_from_storage(&operation.version);
424 Ok(if operation.is_undo {
425 proto::operation::Variant::Undo(proto::operation::Undo {
426 replica_id: row.replica_id as u32,
427 lamport_timestamp: row.lamport_timestamp as u32,
428 version,
429 counts: operation
430 .undo_counts
431 .iter()
432 .map(|entry| proto::UndoCount {
433 replica_id: entry.replica_id,
434 lamport_timestamp: entry.lamport_timestamp,
435 count: entry.count,
436 })
437 .collect(),
438 })
439 } else {
440 proto::operation::Variant::Edit(proto::operation::Edit {
441 replica_id: row.replica_id as u32,
442 lamport_timestamp: row.lamport_timestamp as u32,
443 version,
444 ranges: operation
445 .edit_ranges
446 .into_iter()
447 .map(|range| proto::Range {
448 start: range.start,
449 end: range.end,
450 })
451 .collect(),
452 new_text: operation.edit_texts,
453 })
454 })
455}
456
457fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
458 version
459 .iter()
460 .map(|entry| storage::VectorClockEntry {
461 replica_id: entry.replica_id,
462 timestamp: entry.timestamp,
463 })
464 .collect()
465}
466
467fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
468 version
469 .iter()
470 .map(|entry| proto::VectorClockEntry {
471 replica_id: entry.replica_id,
472 timestamp: entry.timestamp,
473 })
474 .collect()
475}
476
477// This is currently a manual copy of the deserialization code in the client's langauge crate
478pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
479 match operation.variant? {
480 proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
481 timestamp: clock::Lamport {
482 replica_id: edit.replica_id as text::ReplicaId,
483 value: edit.lamport_timestamp,
484 },
485 version: version_from_wire(&edit.version),
486 ranges: edit
487 .ranges
488 .into_iter()
489 .map(|range| {
490 text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
491 })
492 .collect(),
493 new_text: edit.new_text.into_iter().map(Arc::from).collect(),
494 })),
495 proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
496 timestamp: clock::Lamport {
497 replica_id: undo.replica_id as text::ReplicaId,
498 value: undo.lamport_timestamp,
499 },
500 version: version_from_wire(&undo.version),
501 counts: undo
502 .counts
503 .into_iter()
504 .map(|c| {
505 (
506 clock::Lamport {
507 replica_id: c.replica_id as text::ReplicaId,
508 value: c.lamport_timestamp,
509 },
510 c.count,
511 )
512 })
513 .collect(),
514 })),
515 _ => None,
516 }
517}
518
519fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
520 let mut version = clock::Global::new();
521 for entry in message {
522 version.observe(clock::Lamport {
523 replica_id: entry.replica_id as text::ReplicaId,
524 value: entry.timestamp,
525 });
526 }
527 version
528}
529
/// Protobuf messages used to encode buffer operations in the `value` column of
/// `buffer_operation` rows.
///
/// These definitions describe persisted data: changing a tag number or a field
/// type changes how existing rows decode.
mod storage {
    // NOTE(review): the reason for this allow is not visible here; it may be
    // carried over from prost-generated code — confirm whether it is still needed.
    #![allow(non_snake_case)]
    use prost::Message;
    /// Version recorded on new snapshots as `operation_serialization_version`.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// A stored edit or undo. `is_undo` selects which group of fields is
    /// populated: `edit_ranges`/`edit_texts` for edits, `undo_counts` for undos.
    /// The replica id and Lamport timestamp are kept in separate table columns,
    /// not in this message.
    // NOTE(review): tag 1 is not used by any field; presumably it is reserved
    // for a removed field — confirm before reusing it.
    #[derive(Message)]
    pub struct Operation {
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// One entry of a vector clock: the timestamp observed for a replica.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// An edited range, stored as `start`/`end` offsets (converted to
    /// `text::FullOffset` when replayed).
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// The undo count recorded for the operation identified by
    /// (`replica_id`, `lamport_timestamp`).
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub lamport_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}