db.rs

  1#[cfg(test)]
  2pub mod tests;
  3
  4#[cfg(test)]
  5pub use tests::TestDb;
  6
  7mod ids;
  8mod queries;
  9mod tables;
 10
 11use crate::{executor::Executor, Error, Result};
 12use anyhow::anyhow;
 13use collections::{BTreeMap, HashMap, HashSet};
 14use dashmap::DashMap;
 15use futures::StreamExt;
 16use rand::{prelude::StdRng, Rng, SeedableRng};
 17use rpc::{
 18    proto::{self},
 19    ConnectionId,
 20};
 21use sea_orm::{
 22    entity::prelude::*,
 23    sea_query::{Alias, Expr, OnConflict},
 24    ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
 25    FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
 26    TransactionTrait,
 27};
 28use serde::{Deserialize, Serialize};
 29use sqlx::{
 30    migrate::{Migrate, Migration, MigrationSource},
 31    Connection,
 32};
 33use std::{
 34    fmt::Write as _,
 35    future::Future,
 36    marker::PhantomData,
 37    ops::{Deref, DerefMut},
 38    path::Path,
 39    rc::Rc,
 40    sync::Arc,
 41    time::Duration,
 42};
 43use tables::*;
 44use tokio::sync::{Mutex, OwnedMutexGuard};
 45
 46pub use ids::*;
 47pub use sea_orm::ConnectOptions;
 48pub use tables::user::Model as User;
 49
/// Handle to the database: a sea-orm connection pool plus the in-process state used by the
/// transaction helpers in `impl Database`.
pub struct Database {
    /// Connection options; `migrate` uses them to open a separate sqlx connection.
    options: ConnectOptions,
    /// sea-orm connection pool from which every transaction is begun.
    pool: DatabaseConnection,
    /// One async mutex per room. The room transaction helpers lock it, and the owned guard is
    /// handed back inside a `RoomGuard`.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// RNG used to jitter retry delays after serialization failures (seeded with 0 in `new`).
    rng: Mutex<StdRng>,
    /// Used to sleep between retries and, in tests, to simulate random delays.
    executor: Executor,
    /// NOTE(review): presumably filled by `initialize_notification_kinds` (not visible here).
    notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
    /// Reverse lookup of `notification_kinds_by_id`, presumably populated alongside it.
    notification_kinds_by_name: HashMap<String, NotificationKindId>,
    /// Test-only runtime that `run` blocks on; `new` leaves it as `None`.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
 61
 62// The `Database` type has so many methods that its impl blocks are split into
 63// separate files in the `queries` folder.
 64impl Database {
 65    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
 66        sqlx::any::install_default_drivers();
 67        Ok(Self {
 68            options: options.clone(),
 69            pool: sea_orm::Database::connect(options).await?,
 70            rooms: DashMap::with_capacity(16384),
 71            rng: Mutex::new(StdRng::seed_from_u64(0)),
 72            notification_kinds_by_id: HashMap::default(),
 73            notification_kinds_by_name: HashMap::default(),
 74            executor,
 75            #[cfg(test)]
 76            runtime: None,
 77        })
 78    }
 79
 80    #[cfg(test)]
 81    pub fn reset(&self) {
 82        self.rooms.clear();
 83    }
 84
 85    pub async fn migrate(
 86        &self,
 87        migrations_path: &Path,
 88        ignore_checksum_mismatch: bool,
 89    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 90        let migrations = MigrationSource::resolve(migrations_path)
 91            .await
 92            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 93
 94        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
 95
 96        connection.ensure_migrations_table().await?;
 97        let applied_migrations: HashMap<_, _> = connection
 98            .list_applied_migrations()
 99            .await?
100            .into_iter()
101            .map(|m| (m.version, m))
102            .collect();
103
104        let mut new_migrations = Vec::new();
105        for migration in migrations {
106            match applied_migrations.get(&migration.version) {
107                Some(applied_migration) => {
108                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
109                    {
110                        Err(anyhow!(
111                            "checksum mismatch for applied migration {}",
112                            migration.description
113                        ))?;
114                    }
115                }
116                None => {
117                    let elapsed = connection.apply(&migration).await?;
118                    new_migrations.push((migration, elapsed));
119                }
120            }
121        }
122
123        Ok(new_migrations)
124    }
125
126    pub async fn initialize_static_data(&mut self) -> Result<()> {
127        self.initialize_notification_kinds().await?;
128        Ok(())
129    }
130
131    pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
132    where
133        F: Send + Fn(TransactionHandle) -> Fut,
134        Fut: Send + Future<Output = Result<T>>,
135    {
136        let body = async {
137            let mut i = 0;
138            loop {
139                let (tx, result) = self.with_transaction(&f).await?;
140                match result {
141                    Ok(result) => match tx.commit().await.map_err(Into::into) {
142                        Ok(()) => return Ok(result),
143                        Err(error) => {
144                            if !self.retry_on_serialization_error(&error, i).await {
145                                return Err(error);
146                            }
147                        }
148                    },
149                    Err(error) => {
150                        tx.rollback().await?;
151                        if !self.retry_on_serialization_error(&error, i).await {
152                            return Err(error);
153                        }
154                    }
155                }
156                i += 1;
157            }
158        };
159
160        self.run(body).await
161    }
162
163    async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
164    where
165        F: Send + Fn(TransactionHandle) -> Fut,
166        Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
167    {
168        let body = async {
169            let mut i = 0;
170            loop {
171                let (tx, result) = self.with_transaction(&f).await?;
172                match result {
173                    Ok(Some((room_id, data))) => {
174                        let lock = self.rooms.entry(room_id).or_default().clone();
175                        let _guard = lock.lock_owned().await;
176                        match tx.commit().await.map_err(Into::into) {
177                            Ok(()) => {
178                                return Ok(Some(RoomGuard {
179                                    data,
180                                    _guard,
181                                    _not_send: PhantomData,
182                                }));
183                            }
184                            Err(error) => {
185                                if !self.retry_on_serialization_error(&error, i).await {
186                                    return Err(error);
187                                }
188                            }
189                        }
190                    }
191                    Ok(None) => match tx.commit().await.map_err(Into::into) {
192                        Ok(()) => return Ok(None),
193                        Err(error) => {
194                            if !self.retry_on_serialization_error(&error, i).await {
195                                return Err(error);
196                            }
197                        }
198                    },
199                    Err(error) => {
200                        tx.rollback().await?;
201                        if !self.retry_on_serialization_error(&error, i).await {
202                            return Err(error);
203                        }
204                    }
205                }
206                i += 1;
207            }
208        };
209
210        self.run(body).await
211    }
212
213    async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
214    where
215        F: Send + Fn(TransactionHandle) -> Fut,
216        Fut: Send + Future<Output = Result<T>>,
217    {
218        let body = async {
219            let mut i = 0;
220            loop {
221                let lock = self.rooms.entry(room_id).or_default().clone();
222                let _guard = lock.lock_owned().await;
223                let (tx, result) = self.with_transaction(&f).await?;
224                match result {
225                    Ok(data) => match tx.commit().await.map_err(Into::into) {
226                        Ok(()) => {
227                            return Ok(RoomGuard {
228                                data,
229                                _guard,
230                                _not_send: PhantomData,
231                            });
232                        }
233                        Err(error) => {
234                            if !self.retry_on_serialization_error(&error, i).await {
235                                return Err(error);
236                            }
237                        }
238                    },
239                    Err(error) => {
240                        tx.rollback().await?;
241                        if !self.retry_on_serialization_error(&error, i).await {
242                            return Err(error);
243                        }
244                    }
245                }
246                i += 1;
247            }
248        };
249
250        self.run(body).await
251    }
252
253    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
254    where
255        F: Send + Fn(TransactionHandle) -> Fut,
256        Fut: Send + Future<Output = Result<T>>,
257    {
258        let tx = self
259            .pool
260            .begin_with_config(Some(IsolationLevel::Serializable), None)
261            .await?;
262
263        let mut tx = Arc::new(Some(tx));
264        let result = f(TransactionHandle(tx.clone())).await;
265        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
266            return Err(anyhow!(
267                "couldn't complete transaction because it's still in use"
268            ))?;
269        };
270
271        Ok((tx, result))
272    }
273
274    async fn run<F, T>(&self, future: F) -> Result<T>
275    where
276        F: Future<Output = Result<T>>,
277    {
278        #[cfg(test)]
279        {
280            if let Executor::Deterministic(executor) = &self.executor {
281                executor.simulate_random_delay().await;
282            }
283
284            self.runtime.as_ref().unwrap().block_on(future)
285        }
286
287        #[cfg(not(test))]
288        {
289            future.await
290        }
291    }
292
293    async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
294        // If the error is due to a failure to serialize concurrent transactions, then retry
295        // this transaction after a delay. With each subsequent retry, double the delay duration.
296        // Also vary the delay randomly in order to ensure different database connections retry
297        // at different times.
298        if is_serialization_error(error) {
299            let base_delay = 4_u64 << prev_attempt_count.min(16);
300            let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
301            log::info!(
302                "retrying transaction after serialization error. delay: {} ms.",
303                randomized_delay
304            );
305            self.executor
306                .sleep(Duration::from_millis(randomized_delay as u64))
307                .await;
308            true
309        } else {
310            false
311        }
312    }
313}
314
315fn is_serialization_error(error: &Error) -> bool {
316    const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
317    match error {
318        Error::Database(
319            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
320            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
321        ) if error
322            .as_database_error()
323            .and_then(|error| error.code())
324            .as_deref()
325            == Some(SERIALIZATION_FAILURE_CODE) =>
326        {
327            true
328        }
329        _ => false,
330    }
331}
332
333pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
334
335impl Deref for TransactionHandle {
336    type Target = DatabaseTransaction;
337
338    fn deref(&self) -> &Self::Target {
339        self.0.as_ref().as_ref().unwrap()
340    }
341}
342
343pub struct RoomGuard<T> {
344    data: T,
345    _guard: OwnedMutexGuard<()>,
346    _not_send: PhantomData<Rc<()>>,
347}
348
349impl<T> Deref for RoomGuard<T> {
350    type Target = T;
351
352    fn deref(&self) -> &T {
353        &self.data
354    }
355}
356
357impl<T> DerefMut for RoomGuard<T> {
358    fn deref_mut(&mut self) -> &mut T {
359        &mut self.data
360    }
361}
362
363impl<T> RoomGuard<T> {
364    pub fn into_inner(self) -> T {
365        self.data
366    }
367}
368
369#[derive(Clone, Debug, PartialEq, Eq)]
370pub enum Contact {
371    Accepted { user_id: UserId, busy: bool },
372    Outgoing { user_id: UserId },
373    Incoming { user_id: UserId },
374}
375
376impl Contact {
377    pub fn user_id(&self) -> UserId {
378        match self {
379            Contact::Accepted { user_id, .. } => *user_id,
380            Contact::Outgoing { user_id } => *user_id,
381            Contact::Incoming { user_id, .. } => *user_id,
382        }
383    }
384}
385
/// A batch of notifications, each paired with a user id.
///
/// NOTE(review): the id is presumably the recipient — confirm against callers.
pub type NotificationBatch = Vec<(UserId, proto::Notification)>;

/// Result of creating a channel message.
pub struct CreatedChannelMessage {
    /// Id of the newly created message.
    pub message_id: MessageId,
    pub participant_connection_ids: Vec<ConnectionId>,
    pub channel_members: Vec<UserId>,
    /// Notifications generated by the new message.
    pub notifications: NotificationBatch,
}
394
/// An invite tied to an email address and its confirmation code. It can be loaded directly
/// from a query row (`FromQueryResult`) and round-tripped through serde.
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}

/// A signup record. Only `Deserialize` is derived, so values are read in but never
/// serialized by this type.
#[derive(Clone, Debug, Deserialize)]
pub struct NewSignup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
    /// NOTE(review): presumably a database default is used when this is `None` — confirm.
    pub created_at: Option<DateTime>,
}

/// Aggregate signup counts, loadable from a query row via `FromQueryResult`.
///
/// NOTE(review): whether the per-platform counts partition `count` depends on the query that
/// produces this — confirm there.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
pub struct WaitlistSummary {
    pub count: i64,
    pub linux_count: i64,
    pub mac_count: i64,
    pub windows_count: i64,
    pub unknown_count: i64,
}

/// GitHub identity used when creating a user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
}

/// Result of creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
436
/// Result of moving a channel: per-user channel updates and removals, plus the set of
/// channels that moved.
#[derive(Debug)]
pub struct MoveChannelResult {
    pub participants_to_update: HashMap<UserId, ChannelsForUser>,
    pub participants_to_remove: HashSet<UserId>,
    pub moved_channels: HashSet<ChannelId>,
}

/// Result of renaming a channel: the renamed channel and the per-user view to send out.
#[derive(Debug)]
pub struct RenameChannelResult {
    pub channel: Channel,
    pub participants_to_update: HashMap<UserId, Channel>,
}

/// Result of creating a channel.
#[derive(Debug)]
pub struct CreateChannelResult {
    pub channel: Channel,
    pub participants_to_update: Vec<(UserId, ChannelsForUser)>,
}

/// Result of changing a channel's visibility.
#[derive(Debug)]
pub struct SetChannelVisibilityResult {
    pub participants_to_update: HashMap<UserId, ChannelsForUser>,
    pub participants_to_remove: HashSet<UserId>,
    pub channels_to_remove: Vec<ChannelId>,
}

/// A change in one user's channel membership: channels to add and channel ids to remove.
#[derive(Debug)]
pub struct MembershipUpdated {
    pub channel_id: ChannelId,
    pub new_channels: ChannelsForUser,
    pub removed_channels: Vec<ChannelId>,
}

/// Outcome of setting a member's role. The role change either updated a pending invite or
/// an existing membership.
#[derive(Debug)]
pub enum SetMemberRoleResult {
    InviteUpdated(Channel),
    MembershipUpdated(MembershipUpdated),
}

/// Result of inviting a member to a channel.
#[derive(Debug)]
pub struct InviteMemberResult {
    pub channel: Channel,
    pub notifications: NotificationBatch,
}

/// Result of responding to a channel invite; `membership_update` is optional.
#[derive(Debug)]
pub struct RespondToChannelInvite {
    pub membership_update: Option<MembershipUpdated>,
    pub notifications: NotificationBatch,
}

/// Result of removing a channel member.
#[derive(Debug)]
pub struct RemoveChannelMemberResult {
    pub membership_update: MembershipUpdated,
    pub notification_id: Option<NotificationId>,
}
493
494#[derive(Debug, PartialEq, Eq, Hash)]
495pub struct Channel {
496    pub id: ChannelId,
497    pub name: String,
498    pub visibility: ChannelVisibility,
499    pub role: ChannelRole,
500    pub parent_path: Vec<ChannelId>,
501}
502
503impl Channel {
504    fn from_model(value: channel::Model, role: ChannelRole) -> Self {
505        Channel {
506            id: value.id,
507            visibility: value.visibility,
508            name: value.clone().name,
509            role,
510            parent_path: value.ancestors().collect(),
511        }
512    }
513
514    pub fn to_proto(&self) -> proto::Channel {
515        proto::Channel {
516            id: self.id.to_proto(),
517            name: self.name.clone(),
518            visibility: self.visibility.into(),
519            role: self.role.into(),
520            parent_path: self.parent_path.iter().map(|c| c.to_proto()).collect(),
521        }
522    }
523}
524
525#[derive(Debug, PartialEq, Eq, Hash)]
526pub struct ChannelMember {
527    pub role: ChannelRole,
528    pub user_id: UserId,
529    pub kind: proto::channel_member::Kind,
530}
531
532impl ChannelMember {
533    pub fn to_proto(&self) -> proto::ChannelMember {
534        proto::ChannelMember {
535            role: self.role.into(),
536            user_id: self.user_id.to_proto(),
537            kind: self.kind.into(),
538        }
539    }
540}
541
/// The channel state to send to one user: channels, participants per channel, and
/// unseen-activity summaries.
#[derive(Debug, PartialEq)]
pub struct ChannelsForUser {
    pub channels: Vec<Channel>,
    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
    pub unseen_buffer_changes: Vec<proto::UnseenChannelBufferChange>,
    pub channel_messages: Vec<proto::UnseenChannelMessage>,
}

/// A channel buffer rejoin, paired with the connection id the client previously used.
#[derive(Debug)]
pub struct RejoinedChannelBuffer {
    pub buffer: proto::RejoinedChannelBuffer,
    pub old_connection_id: ConnectionId,
}
555
/// Result of joining a room; `channel_id` is set when the room belongs to a channel.
#[derive(Clone)]
pub struct JoinRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}

/// Result of rejoining a room, including projects rejoined as a guest and projects reshared
/// as a host.
pub struct RejoinedRoom {
    pub room: proto::Room,
    pub rejoined_projects: Vec<RejoinedProject>,
    pub reshared_projects: Vec<ResharedProject>,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}

/// A project reshared after reconnecting, with the host's previous connection id.
pub struct ResharedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<proto::WorktreeMetadata>,
}

/// A project rejoined after reconnecting, with the previous connection id and the
/// worktree changes to replay.
pub struct RejoinedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<RejoinedWorktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}

/// Worktree state for a rejoining client: updated and removed entries and repositories,
/// plus diagnostics, settings files, and scan ids.
#[derive(Debug)]
pub struct RejoinedWorktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub updated_entries: Vec<proto::Entry>,
    pub removed_entries: Vec<u64>,
    pub updated_repositories: Vec<proto::RepositoryEntry>,
    pub removed_repositories: Vec<u64>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
601
/// Result of leaving a room: the remaining room state, projects the user left, canceled
/// calls, and whether the room was deleted.
pub struct LeftRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
    pub deleted: bool,
}

/// Result of refreshing a room, listing participants considered stale and calls canceled.
pub struct RefreshedRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub stale_participant_user_ids: Vec<UserId>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}

/// Result of refreshing a channel buffer: connections and current collaborators.
pub struct RefreshedChannelBuffer {
    pub connection_ids: Vec<ConnectionId>,
    pub collaborators: Vec<proto::Collaborator>,
}

/// A project's collaborators, worktrees (keyed by worktree id), and language servers.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
629
630pub struct ProjectCollaborator {
631    pub connection_id: ConnectionId,
632    pub user_id: UserId,
633    pub replica_id: ReplicaId,
634    pub is_host: bool,
635}
636
637impl ProjectCollaborator {
638    pub fn to_proto(&self) -> proto::Collaborator {
639        proto::Collaborator {
640            peer_id: Some(self.connection_id.into()),
641            replica_id: self.replica_id.0 as u32,
642            user_id: self.user_id.to_proto(),
643        }
644    }
645}
646
/// A project the user left, with its host and the remaining connection ids.
#[derive(Debug)]
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    pub connection_ids: Vec<ConnectionId>,
}

/// Full stored state of a worktree; repository entries are keyed by a `u64` id.
pub struct Worktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}

/// A settings file within a worktree: its path and raw content.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    pub path: String,
    pub content: String,
}