1#[cfg(test)]
2pub mod tests;
3
4#[cfg(test)]
5pub use tests::TestDb;
6
7mod ids;
8mod queries;
9mod tables;
10
11use crate::{executor::Executor, Error, Result};
12use anyhow::anyhow;
13use collections::{BTreeMap, HashMap, HashSet};
14use dashmap::DashMap;
15use futures::StreamExt;
16use rand::{prelude::StdRng, Rng, SeedableRng};
17use rpc::{
18 proto::{self},
19 ConnectionId,
20};
21use sea_orm::{
22 entity::prelude::*,
23 sea_query::{Alias, Expr, OnConflict},
24 ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
25 FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
26 TransactionTrait,
27};
28use serde::{Deserialize, Serialize};
29use sqlx::{
30 migrate::{Migrate, Migration, MigrationSource},
31 Connection,
32};
33use std::{
34 fmt::Write as _,
35 future::Future,
36 marker::PhantomData,
37 ops::{Deref, DerefMut},
38 path::Path,
39 rc::Rc,
40 sync::Arc,
41 time::Duration,
42};
43use tables::*;
44use tokio::sync::{Mutex, OwnedMutexGuard};
45
46pub use ids::*;
47pub use sea_orm::ConnectOptions;
48pub use tables::user::Model as User;
49
/// Handle to the collaboration database.
///
/// Only the core transaction machinery is implemented next to this type; most
/// query methods live in additional `impl Database` blocks in the `queries` module.
pub struct Database {
    /// Connection options; kept so that `migrate` can open its own `sqlx`
    /// connection to the same URL.
    options: ConnectOptions,
    /// `sea_orm` connection that `with_transaction` begins transactions on.
    pool: DatabaseConnection,
    /// One lock per room, created lazily via `entry(..).or_default()`. Room-scoped
    /// transactions take this lock and hand it to the caller inside a `RoomGuard`.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// Source of the random jitter applied when retrying serialization failures
    /// (seeded with a constant in `new`).
    rng: Mutex<StdRng>,
    /// Executor used to sleep between retries and, in tests, to simulate random delays.
    executor: Executor,
    /// Notification kind names keyed by id. Empty after `new`; presumably filled by
    /// `initialize_notification_kinds` (called from `initialize_static_data`) — confirm.
    notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
    /// Notification kind ids keyed by name; presumably the inverse of
    /// `notification_kinds_by_id` — confirm.
    notification_kinds_by_name: HashMap<String, NotificationKindId>,
    /// Test-only runtime that `run` blocks on.
    // NOTE: fields drop in declaration order, so `pool` is dropped before this runtime.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
61
62// The `Database` type has so many methods that its impl blocks are split into
63// separate files in the `queries` folder.
64impl Database {
65 pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
66 sqlx::any::install_default_drivers();
67 Ok(Self {
68 options: options.clone(),
69 pool: sea_orm::Database::connect(options).await?,
70 rooms: DashMap::with_capacity(16384),
71 rng: Mutex::new(StdRng::seed_from_u64(0)),
72 notification_kinds_by_id: HashMap::default(),
73 notification_kinds_by_name: HashMap::default(),
74 executor,
75 #[cfg(test)]
76 runtime: None,
77 })
78 }
79
80 #[cfg(test)]
81 pub fn reset(&self) {
82 self.rooms.clear();
83 }
84
85 pub async fn migrate(
86 &self,
87 migrations_path: &Path,
88 ignore_checksum_mismatch: bool,
89 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
90 let migrations = MigrationSource::resolve(migrations_path)
91 .await
92 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
93
94 let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
95
96 connection.ensure_migrations_table().await?;
97 let applied_migrations: HashMap<_, _> = connection
98 .list_applied_migrations()
99 .await?
100 .into_iter()
101 .map(|m| (m.version, m))
102 .collect();
103
104 let mut new_migrations = Vec::new();
105 for migration in migrations {
106 match applied_migrations.get(&migration.version) {
107 Some(applied_migration) => {
108 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
109 {
110 Err(anyhow!(
111 "checksum mismatch for applied migration {}",
112 migration.description
113 ))?;
114 }
115 }
116 None => {
117 let elapsed = connection.apply(&migration).await?;
118 new_migrations.push((migration, elapsed));
119 }
120 }
121 }
122
123 Ok(new_migrations)
124 }
125
126 pub async fn initialize_static_data(&mut self) -> Result<()> {
127 self.initialize_notification_kinds().await?;
128 Ok(())
129 }
130
131 pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
132 where
133 F: Send + Fn(TransactionHandle) -> Fut,
134 Fut: Send + Future<Output = Result<T>>,
135 {
136 let body = async {
137 let mut i = 0;
138 loop {
139 let (tx, result) = self.with_transaction(&f).await?;
140 match result {
141 Ok(result) => match tx.commit().await.map_err(Into::into) {
142 Ok(()) => return Ok(result),
143 Err(error) => {
144 if !self.retry_on_serialization_error(&error, i).await {
145 return Err(error);
146 }
147 }
148 },
149 Err(error) => {
150 tx.rollback().await?;
151 if !self.retry_on_serialization_error(&error, i).await {
152 return Err(error);
153 }
154 }
155 }
156 i += 1;
157 }
158 };
159
160 self.run(body).await
161 }
162
163 async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
164 where
165 F: Send + Fn(TransactionHandle) -> Fut,
166 Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
167 {
168 let body = async {
169 let mut i = 0;
170 loop {
171 let (tx, result) = self.with_transaction(&f).await?;
172 match result {
173 Ok(Some((room_id, data))) => {
174 let lock = self.rooms.entry(room_id).or_default().clone();
175 let _guard = lock.lock_owned().await;
176 match tx.commit().await.map_err(Into::into) {
177 Ok(()) => {
178 return Ok(Some(RoomGuard {
179 data,
180 _guard,
181 _not_send: PhantomData,
182 }));
183 }
184 Err(error) => {
185 if !self.retry_on_serialization_error(&error, i).await {
186 return Err(error);
187 }
188 }
189 }
190 }
191 Ok(None) => match tx.commit().await.map_err(Into::into) {
192 Ok(()) => return Ok(None),
193 Err(error) => {
194 if !self.retry_on_serialization_error(&error, i).await {
195 return Err(error);
196 }
197 }
198 },
199 Err(error) => {
200 tx.rollback().await?;
201 if !self.retry_on_serialization_error(&error, i).await {
202 return Err(error);
203 }
204 }
205 }
206 i += 1;
207 }
208 };
209
210 self.run(body).await
211 }
212
213 async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
214 where
215 F: Send + Fn(TransactionHandle) -> Fut,
216 Fut: Send + Future<Output = Result<T>>,
217 {
218 let body = async {
219 let mut i = 0;
220 loop {
221 let lock = self.rooms.entry(room_id).or_default().clone();
222 let _guard = lock.lock_owned().await;
223 let (tx, result) = self.with_transaction(&f).await?;
224 match result {
225 Ok(data) => match tx.commit().await.map_err(Into::into) {
226 Ok(()) => {
227 return Ok(RoomGuard {
228 data,
229 _guard,
230 _not_send: PhantomData,
231 });
232 }
233 Err(error) => {
234 if !self.retry_on_serialization_error(&error, i).await {
235 return Err(error);
236 }
237 }
238 },
239 Err(error) => {
240 tx.rollback().await?;
241 if !self.retry_on_serialization_error(&error, i).await {
242 return Err(error);
243 }
244 }
245 }
246 i += 1;
247 }
248 };
249
250 self.run(body).await
251 }
252
253 async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
254 where
255 F: Send + Fn(TransactionHandle) -> Fut,
256 Fut: Send + Future<Output = Result<T>>,
257 {
258 let tx = self
259 .pool
260 .begin_with_config(Some(IsolationLevel::Serializable), None)
261 .await?;
262
263 let mut tx = Arc::new(Some(tx));
264 let result = f(TransactionHandle(tx.clone())).await;
265 let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
266 return Err(anyhow!(
267 "couldn't complete transaction because it's still in use"
268 ))?;
269 };
270
271 Ok((tx, result))
272 }
273
274 async fn run<F, T>(&self, future: F) -> Result<T>
275 where
276 F: Future<Output = Result<T>>,
277 {
278 #[cfg(test)]
279 {
280 if let Executor::Deterministic(executor) = &self.executor {
281 executor.simulate_random_delay().await;
282 }
283
284 self.runtime.as_ref().unwrap().block_on(future)
285 }
286
287 #[cfg(not(test))]
288 {
289 future.await
290 }
291 }
292
293 async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
294 // If the error is due to a failure to serialize concurrent transactions, then retry
295 // this transaction after a delay. With each subsequent retry, double the delay duration.
296 // Also vary the delay randomly in order to ensure different database connections retry
297 // at different times.
298 if is_serialization_error(error) {
299 let base_delay = 4_u64 << prev_attempt_count.min(16);
300 let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
301 log::info!(
302 "retrying transaction after serialization error. delay: {} ms.",
303 randomized_delay
304 );
305 self.executor
306 .sleep(Duration::from_millis(randomized_delay as u64))
307 .await;
308 true
309 } else {
310 false
311 }
312 }
313}
314
315fn is_serialization_error(error: &Error) -> bool {
316 const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
317 match error {
318 Error::Database(
319 DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
320 | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
321 ) if error
322 .as_database_error()
323 .and_then(|error| error.code())
324 .as_deref()
325 == Some(SERIALIZATION_FAILURE_CODE) =>
326 {
327 true
328 }
329 _ => false,
330 }
331}
332
333pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
334
335impl Deref for TransactionHandle {
336 type Target = DatabaseTransaction;
337
338 fn deref(&self) -> &Self::Target {
339 self.0.as_ref().as_ref().unwrap()
340 }
341}
342
343pub struct RoomGuard<T> {
344 data: T,
345 _guard: OwnedMutexGuard<()>,
346 _not_send: PhantomData<Rc<()>>,
347}
348
349impl<T> Deref for RoomGuard<T> {
350 type Target = T;
351
352 fn deref(&self) -> &T {
353 &self.data
354 }
355}
356
357impl<T> DerefMut for RoomGuard<T> {
358 fn deref_mut(&mut self) -> &mut T {
359 &mut self.data
360 }
361}
362
363impl<T> RoomGuard<T> {
364 pub fn into_inner(self) -> T {
365 self.data
366 }
367}
368
369#[derive(Clone, Debug, PartialEq, Eq)]
370pub enum Contact {
371 Accepted { user_id: UserId, busy: bool },
372 Outgoing { user_id: UserId },
373 Incoming { user_id: UserId },
374}
375
376impl Contact {
377 pub fn user_id(&self) -> UserId {
378 match self {
379 Contact::Accepted { user_id, .. } => *user_id,
380 Contact::Outgoing { user_id } => *user_id,
381 Contact::Incoming { user_id, .. } => *user_id,
382 }
383 }
384}
385
/// Notifications paired with a user id; presumably each user id is the
/// notification's recipient — confirm against the producers in `queries`.
pub type NotificationBatch = Vec<(UserId, proto::Notification)>;

/// Data returned after a channel message is stored.
/// NOTE(review): description inferred from the type name; the producing query is not in this file.
pub struct CreatedChannelMessage {
    pub message_id: MessageId,
    pub participant_connection_ids: Vec<ConnectionId>,
    pub channel_members: Vec<UserId>,
    /// Notifications to send as a consequence of the message.
    pub notifications: NotificationBatch,
}
394
/// An invited email address and its confirmation code.
///
/// Derives `FromQueryResult`, so it can be read directly from a query row, and
/// `Serialize`/`Deserialize` for transport.
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}

/// Details of a new signup, deserialized from incoming data.
/// NOTE(review): presumably a waitlist signup request body — confirm.
#[derive(Clone, Debug, Deserialize)]
pub struct NewSignup {
    pub email_address: String,
    // One flag per platform the signup indicated.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
    /// Optional creation timestamp; presumably the database default is used when `None` — confirm.
    pub created_at: Option<DateTime>,
}

/// Aggregate signup counts, loadable from a query row via `FromQueryResult`.
/// NOTE(review): `count` is presumably the total and the others per-platform subsets — confirm.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
pub struct WaitlistSummary {
    pub count: i64,
    pub linux_count: i64,
    pub mac_count: i64,
    pub windows_count: i64,
    pub unknown_count: i64,
}

/// GitHub identity used when creating a user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
}

/// Data returned after creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
436
// Result types returned by channel-management queries (implemented in `queries`).
// NOTE(review): per-type descriptions are inferred from type names — confirm.

/// Data returned after moving a channel: channel lists to resend, users who lost
/// access, and the set of channels affected.
#[derive(Debug)]
pub struct MoveChannelResult {
    pub participants_to_update: HashMap<UserId, ChannelsForUser>,
    pub participants_to_remove: HashSet<UserId>,
    pub moved_channels: HashSet<ChannelId>,
}

/// Data returned after renaming a channel, with a per-user view of the channel.
#[derive(Debug)]
pub struct RenameChannelResult {
    pub channel: Channel,
    pub participants_to_update: HashMap<UserId, Channel>,
}

/// Data returned after creating a channel.
#[derive(Debug)]
pub struct CreateChannelResult {
    pub channel: Channel,
    pub participants_to_update: Vec<(UserId, ChannelsForUser)>,
}

/// Data returned after changing a channel's visibility.
#[derive(Debug)]
pub struct SetChannelVisibilityResult {
    pub participants_to_update: HashMap<UserId, ChannelsForUser>,
    pub participants_to_remove: HashSet<UserId>,
    pub channels_to_remove: Vec<ChannelId>,
}

/// A change to one user's channel membership: channels gained and channels lost.
#[derive(Debug)]
pub struct MembershipUpdated {
    pub channel_id: ChannelId,
    pub new_channels: ChannelsForUser,
    pub removed_channels: Vec<ChannelId>,
}

/// Outcome of setting a member's role: either a pending invite was updated or an
/// existing membership changed.
#[derive(Debug)]
pub enum SetMemberRoleResult {
    InviteUpdated(Channel),
    MembershipUpdated(MembershipUpdated),
}

/// Data returned after inviting a member to a channel.
#[derive(Debug)]
pub struct InviteMemberResult {
    pub channel: Channel,
    pub notifications: NotificationBatch,
}

/// Data returned after responding to a channel invite; `membership_update` is
/// `None` when there is no membership change to report.
#[derive(Debug)]
pub struct RespondToChannelInvite {
    pub membership_update: Option<MembershipUpdated>,
    pub notifications: NotificationBatch,
}

/// Data returned after removing a channel member.
#[derive(Debug)]
pub struct RemoveChannelMemberResult {
    pub membership_update: MembershipUpdated,
    pub notification_id: Option<NotificationId>,
}
493
494#[derive(Debug, PartialEq, Eq, Hash)]
495pub struct Channel {
496 pub id: ChannelId,
497 pub name: String,
498 pub visibility: ChannelVisibility,
499 pub role: ChannelRole,
500 pub parent_path: Vec<ChannelId>,
501}
502
503impl Channel {
504 fn from_model(value: channel::Model, role: ChannelRole) -> Self {
505 Channel {
506 id: value.id,
507 visibility: value.visibility,
508 name: value.clone().name,
509 role,
510 parent_path: value.ancestors().collect(),
511 }
512 }
513
514 pub fn to_proto(&self) -> proto::Channel {
515 proto::Channel {
516 id: self.id.to_proto(),
517 name: self.name.clone(),
518 visibility: self.visibility.into(),
519 role: self.role.into(),
520 parent_path: self.parent_path.iter().map(|c| c.to_proto()).collect(),
521 }
522 }
523}
524
525#[derive(Debug, PartialEq, Eq, Hash)]
526pub struct ChannelMember {
527 pub role: ChannelRole,
528 pub user_id: UserId,
529 pub kind: proto::channel_member::Kind,
530}
531
532impl ChannelMember {
533 pub fn to_proto(&self) -> proto::ChannelMember {
534 proto::ChannelMember {
535 role: self.role.into(),
536 user_id: self.user_id.to_proto(),
537 kind: self.kind.into(),
538 }
539 }
540}
541
/// The channel-related state sent to one user.
/// NOTE(review): description inferred from the type and field names — confirm.
#[derive(Debug, PartialEq)]
pub struct ChannelsForUser {
    pub channels: Vec<Channel>,
    /// User ids participating in each channel, keyed by channel id.
    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
    pub unseen_buffer_changes: Vec<proto::UnseenChannelBufferChange>,
    pub channel_messages: Vec<proto::UnseenChannelMessage>,
}

/// A channel buffer that a reconnecting client rejoined, together with the
/// connection id it was using before.
#[derive(Debug)]
pub struct RejoinedChannelBuffer {
    pub buffer: proto::RejoinedChannelBuffer,
    pub old_connection_id: ConnectionId,
}
555
/// Data returned after joining a room.
/// `channel_id` is presumably the channel the room belongs to, if any — confirm.
#[derive(Clone)]
pub struct JoinRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}

/// Data returned after a client rejoins a room (presumably after reconnecting —
/// confirm), including the projects it rejoined or reshared.
pub struct RejoinedRoom {
    pub room: proto::Room,
    pub rejoined_projects: Vec<RejoinedProject>,
    pub reshared_projects: Vec<ResharedProject>,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}

/// A project shared again by its host on rejoin, with the host's previous connection id.
/// NOTE(review): "host" is inferred from the name — confirm.
pub struct ResharedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<proto::WorktreeMetadata>,
}

/// A project rejoined by a collaborator, with the state needed to bring it up to date.
pub struct RejoinedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<RejoinedWorktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}

/// A worktree in a rejoined project.
///
/// Carries updated and removed entries and repositories, presumably relative to
/// what the client last saw — confirm.
#[derive(Debug)]
pub struct RejoinedWorktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub updated_entries: Vec<proto::Entry>,
    pub removed_entries: Vec<u64>,
    pub updated_repositories: Vec<proto::RepositoryEntry>,
    pub removed_repositories: Vec<u64>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
601
/// Data returned after a user leaves a room.
pub struct LeftRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    /// Projects the user left as a result, keyed by project id.
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
    /// Presumably `true` when the room itself was deleted — confirm.
    pub deleted: bool,
}

/// Data returned after refreshing a room.
/// NOTE(review): presumably after stale participants were removed — confirm.
pub struct RefreshedRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub stale_participant_user_ids: Vec<UserId>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}

/// The current connections and collaborators of a refreshed channel buffer.
pub struct RefreshedChannelBuffer {
    pub connection_ids: Vec<ConnectionId>,
    pub collaborators: Vec<proto::Collaborator>,
}
623
/// A shared project: its collaborators, worktrees, and language servers.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    /// Worktrees keyed by a `u64`; presumably the worktree id — confirm.
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}

/// One participant in a project, identified by connection, user, and replica id.
pub struct ProjectCollaborator {
    pub connection_id: ConnectionId,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    pub is_host: bool,
}
636
637impl ProjectCollaborator {
638 pub fn to_proto(&self) -> proto::Collaborator {
639 proto::Collaborator {
640 peer_id: Some(self.connection_id.into()),
641 replica_id: self.replica_id.0 as u32,
642 user_id: self.user_id.to_proto(),
643 }
644 }
645}
646
/// A project a user left, with its host and the connections that remain in it.
/// NOTE(review): whether `connection_ids` are the remaining guests is inferred — confirm.
#[derive(Debug)]
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    pub connection_ids: Vec<ConnectionId>,
}

/// The full stored state of one worktree in a project.
pub struct Worktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    /// Repository entries keyed by a `u64` (presumably the work-directory entry id — confirm).
    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}

/// A settings file in a worktree: its path and raw contents.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    pub path: String,
    pub content: String,
}