// db.rs

  1#[cfg(test)]
  2pub mod tests;
  3
  4#[cfg(test)]
  5pub use tests::TestDb;
  6
  7mod ids;
  8mod queries;
  9mod tables;
 10
 11use crate::{executor::Executor, Error, Result};
 12use anyhow::anyhow;
 13use collections::{BTreeMap, HashMap, HashSet};
 14use dashmap::DashMap;
 15use futures::StreamExt;
 16use rand::{prelude::StdRng, Rng, SeedableRng};
 17use rpc::{
 18    proto::{self},
 19    ConnectionId,
 20};
 21use sea_orm::{
 22    entity::prelude::*,
 23    sea_query::{Alias, Expr, OnConflict},
 24    ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
 25    FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
 26    TransactionTrait,
 27};
 28use serde::{Deserialize, Serialize};
 29use sqlx::{
 30    migrate::{Migrate, Migration, MigrationSource},
 31    Connection,
 32};
 33use std::{
 34    fmt::Write as _,
 35    future::Future,
 36    marker::PhantomData,
 37    ops::{Deref, DerefMut},
 38    path::Path,
 39    rc::Rc,
 40    sync::Arc,
 41    time::Duration,
 42};
 43use tables::*;
 44use tokio::sync::{Mutex, OwnedMutexGuard};
 45
 46pub use ids::*;
 47pub use queries::contributors::ContributorSelector;
 48pub use sea_orm::ConnectOptions;
 49pub use tables::user::Model as User;
 50
 51/// Database gives you a handle that lets you access the database.
 52/// It handles pooling internally.
 53pub struct Database {
 54    options: ConnectOptions,
 55    pool: DatabaseConnection,
 56    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
 57    rng: Mutex<StdRng>,
 58    executor: Executor,
 59    notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
 60    notification_kinds_by_name: HashMap<String, NotificationKindId>,
 61    #[cfg(test)]
 62    runtime: Option<tokio::runtime::Runtime>,
 63}
 64
 65// The `Database` type has so many methods that its impl blocks are split into
 66// separate files in the `queries` folder.
 67impl Database {
 68    /// Connects to the database with the given options
 69    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
 70        sqlx::any::install_default_drivers();
 71        Ok(Self {
 72            options: options.clone(),
 73            pool: sea_orm::Database::connect(options).await?,
 74            rooms: DashMap::with_capacity(16384),
 75            rng: Mutex::new(StdRng::seed_from_u64(0)),
 76            notification_kinds_by_id: HashMap::default(),
 77            notification_kinds_by_name: HashMap::default(),
 78            executor,
 79            #[cfg(test)]
 80            runtime: None,
 81        })
 82    }
 83
 84    #[cfg(test)]
 85    pub fn reset(&self) {
 86        self.rooms.clear();
 87    }
 88
 89    /// Runs the database migrations.
 90    pub async fn migrate(
 91        &self,
 92        migrations_path: &Path,
 93        ignore_checksum_mismatch: bool,
 94    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 95        let migrations = MigrationSource::resolve(migrations_path)
 96            .await
 97            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 98
 99        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
100
101        connection.ensure_migrations_table().await?;
102        let applied_migrations: HashMap<_, _> = connection
103            .list_applied_migrations()
104            .await?
105            .into_iter()
106            .map(|m| (m.version, m))
107            .collect();
108
109        let mut new_migrations = Vec::new();
110        for migration in migrations {
111            match applied_migrations.get(&migration.version) {
112                Some(applied_migration) => {
113                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
114                    {
115                        Err(anyhow!(
116                            "checksum mismatch for applied migration {}",
117                            migration.description
118                        ))?;
119                    }
120                }
121                None => {
122                    let elapsed = connection.apply(&migration).await?;
123                    new_migrations.push((migration, elapsed));
124                }
125            }
126        }
127
128        Ok(new_migrations)
129    }
130
131    /// Initializes static data that resides in the database by upserting it.
132    pub async fn initialize_static_data(&mut self) -> Result<()> {
133        self.initialize_notification_kinds().await?;
134        Ok(())
135    }
136
137    /// Transaction runs things in a transaction. If you want to call other methods
138    /// and pass the transaction around you need to reborrow the transaction at each
139    /// call site with: `&*tx`.
140    pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
141    where
142        F: Send + Fn(TransactionHandle) -> Fut,
143        Fut: Send + Future<Output = Result<T>>,
144    {
145        let body = async {
146            let mut i = 0;
147            loop {
148                let (tx, result) = self.with_transaction(&f).await?;
149                match result {
150                    Ok(result) => match tx.commit().await.map_err(Into::into) {
151                        Ok(()) => return Ok(result),
152                        Err(error) => {
153                            if !self.retry_on_serialization_error(&error, i).await {
154                                return Err(error);
155                            }
156                        }
157                    },
158                    Err(error) => {
159                        tx.rollback().await?;
160                        if !self.retry_on_serialization_error(&error, i).await {
161                            return Err(error);
162                        }
163                    }
164                }
165                i += 1;
166            }
167        };
168
169        self.run(body).await
170    }
171
172    pub async fn weak_transaction<F, Fut, T>(&self, f: F) -> Result<T>
173    where
174        F: Send + Fn(TransactionHandle) -> Fut,
175        Fut: Send + Future<Output = Result<T>>,
176    {
177        let body = async {
178            let (tx, result) = self.with_weak_transaction(&f).await?;
179            match result {
180                Ok(result) => match tx.commit().await.map_err(Into::into) {
181                    Ok(()) => return Ok(result),
182                    Err(error) => {
183                        return Err(error);
184                    }
185                },
186                Err(error) => {
187                    tx.rollback().await?;
188                    return Err(error);
189                }
190            }
191        };
192
193        self.run(body).await
194    }
195
196    /// The same as room_transaction, but if you need to only optionally return a Room.
197    async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
198    where
199        F: Send + Fn(TransactionHandle) -> Fut,
200        Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
201    {
202        let body = async {
203            let mut i = 0;
204            loop {
205                let (tx, result) = self.with_transaction(&f).await?;
206                match result {
207                    Ok(Some((room_id, data))) => {
208                        let lock = self.rooms.entry(room_id).or_default().clone();
209                        let _guard = lock.lock_owned().await;
210                        match tx.commit().await.map_err(Into::into) {
211                            Ok(()) => {
212                                return Ok(Some(RoomGuard {
213                                    data,
214                                    _guard,
215                                    _not_send: PhantomData,
216                                }));
217                            }
218                            Err(error) => {
219                                if !self.retry_on_serialization_error(&error, i).await {
220                                    return Err(error);
221                                }
222                            }
223                        }
224                    }
225                    Ok(None) => match tx.commit().await.map_err(Into::into) {
226                        Ok(()) => return Ok(None),
227                        Err(error) => {
228                            if !self.retry_on_serialization_error(&error, i).await {
229                                return Err(error);
230                            }
231                        }
232                    },
233                    Err(error) => {
234                        tx.rollback().await?;
235                        if !self.retry_on_serialization_error(&error, i).await {
236                            return Err(error);
237                        }
238                    }
239                }
240                i += 1;
241            }
242        };
243
244        self.run(body).await
245    }
246
247    /// room_transaction runs the block in a transaction. It returns a RoomGuard, that keeps
248    /// the database locked until it is dropped. This ensures that updates sent to clients are
249    /// properly serialized with respect to database changes.
250    async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
251    where
252        F: Send + Fn(TransactionHandle) -> Fut,
253        Fut: Send + Future<Output = Result<T>>,
254    {
255        let body = async {
256            let mut i = 0;
257            loop {
258                let lock = self.rooms.entry(room_id).or_default().clone();
259                let _guard = lock.lock_owned().await;
260                let (tx, result) = self.with_transaction(&f).await?;
261                match result {
262                    Ok(data) => match tx.commit().await.map_err(Into::into) {
263                        Ok(()) => {
264                            return Ok(RoomGuard {
265                                data,
266                                _guard,
267                                _not_send: PhantomData,
268                            });
269                        }
270                        Err(error) => {
271                            if !self.retry_on_serialization_error(&error, i).await {
272                                return Err(error);
273                            }
274                        }
275                    },
276                    Err(error) => {
277                        tx.rollback().await?;
278                        if !self.retry_on_serialization_error(&error, i).await {
279                            return Err(error);
280                        }
281                    }
282                }
283                i += 1;
284            }
285        };
286
287        self.run(body).await
288    }
289
290    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
291    where
292        F: Send + Fn(TransactionHandle) -> Fut,
293        Fut: Send + Future<Output = Result<T>>,
294    {
295        let tx = self
296            .pool
297            .begin_with_config(Some(IsolationLevel::Serializable), None)
298            .await?;
299
300        let mut tx = Arc::new(Some(tx));
301        let result = f(TransactionHandle(tx.clone())).await;
302        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
303            return Err(anyhow!(
304                "couldn't complete transaction because it's still in use"
305            ))?;
306        };
307
308        Ok((tx, result))
309    }
310
311    async fn with_weak_transaction<F, Fut, T>(
312        &self,
313        f: &F,
314    ) -> Result<(DatabaseTransaction, Result<T>)>
315    where
316        F: Send + Fn(TransactionHandle) -> Fut,
317        Fut: Send + Future<Output = Result<T>>,
318    {
319        let tx = self
320            .pool
321            .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
322            .await?;
323
324        let mut tx = Arc::new(Some(tx));
325        let result = f(TransactionHandle(tx.clone())).await;
326        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
327            return Err(anyhow!(
328                "couldn't complete transaction because it's still in use"
329            ))?;
330        };
331
332        Ok((tx, result))
333    }
334
335    async fn run<F, T>(&self, future: F) -> Result<T>
336    where
337        F: Future<Output = Result<T>>,
338    {
339        #[cfg(test)]
340        {
341            if let Executor::Deterministic(executor) = &self.executor {
342                executor.simulate_random_delay().await;
343            }
344
345            self.runtime.as_ref().unwrap().block_on(future)
346        }
347
348        #[cfg(not(test))]
349        {
350            future.await
351        }
352    }
353
354    async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: usize) -> bool {
355        // If the error is due to a failure to serialize concurrent transactions, then retry
356        // this transaction after a delay. With each subsequent retry, double the delay duration.
357        // Also vary the delay randomly in order to ensure different database connections retry
358        // at different times.
359        const SLEEPS: [f32; 10] = [10., 20., 40., 80., 160., 320., 640., 1280., 2560., 5120.];
360        if is_serialization_error(error) && prev_attempt_count < SLEEPS.len() {
361            let base_delay = SLEEPS[prev_attempt_count];
362            let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
363            log::info!(
364                "retrying transaction after serialization error. delay: {} ms.",
365                randomized_delay
366            );
367            self.executor
368                .sleep(Duration::from_millis(randomized_delay as u64))
369                .await;
370            true
371        } else {
372            false
373        }
374    }
375}
376
377fn is_serialization_error(error: &Error) -> bool {
378    const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
379    match error {
380        Error::Database(
381            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
382            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
383        ) if error
384            .as_database_error()
385            .and_then(|error| error.code())
386            .as_deref()
387            == Some(SERIALIZATION_FAILURE_CODE) =>
388        {
389            true
390        }
391        _ => false,
392    }
393}
394
395/// A handle to a [`DatabaseTransaction`].
396pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
397
398impl Deref for TransactionHandle {
399    type Target = DatabaseTransaction;
400
401    fn deref(&self) -> &Self::Target {
402        self.0.as_ref().as_ref().unwrap()
403    }
404}
405
406/// [`RoomGuard`] keeps a database transaction alive until it is dropped.
407/// so that updates to rooms are serialized.
408pub struct RoomGuard<T> {
409    data: T,
410    _guard: OwnedMutexGuard<()>,
411    _not_send: PhantomData<Rc<()>>,
412}
413
414impl<T> Deref for RoomGuard<T> {
415    type Target = T;
416
417    fn deref(&self) -> &T {
418        &self.data
419    }
420}
421
422impl<T> DerefMut for RoomGuard<T> {
423    fn deref_mut(&mut self) -> &mut T {
424        &mut self.data
425    }
426}
427
428impl<T> RoomGuard<T> {
429    /// Returns the inner value of the guard.
430    pub fn into_inner(self) -> T {
431        self.data
432    }
433}
434
435#[derive(Clone, Debug, PartialEq, Eq)]
436pub enum Contact {
437    Accepted { user_id: UserId, busy: bool },
438    Outgoing { user_id: UserId },
439    Incoming { user_id: UserId },
440}
441
442impl Contact {
443    pub fn user_id(&self) -> UserId {
444        match self {
445            Contact::Accepted { user_id, .. } => *user_id,
446            Contact::Outgoing { user_id } => *user_id,
447            Contact::Incoming { user_id, .. } => *user_id,
448        }
449    }
450}
451
452pub type NotificationBatch = Vec<(UserId, proto::Notification)>;
453
454pub struct CreatedChannelMessage {
455    pub message_id: MessageId,
456    pub participant_connection_ids: Vec<ConnectionId>,
457    pub channel_members: Vec<UserId>,
458    pub notifications: NotificationBatch,
459}
460
461#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
462pub struct Invite {
463    pub email_address: String,
464    pub email_confirmation_code: String,
465}
466
467#[derive(Clone, Debug, Deserialize)]
468pub struct NewSignup {
469    pub email_address: String,
470    pub platform_mac: bool,
471    pub platform_windows: bool,
472    pub platform_linux: bool,
473    pub editor_features: Vec<String>,
474    pub programming_languages: Vec<String>,
475    pub device_id: Option<String>,
476    pub added_to_mailing_list: bool,
477    pub created_at: Option<DateTime>,
478}
479
480#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
481pub struct WaitlistSummary {
482    pub count: i64,
483    pub linux_count: i64,
484    pub mac_count: i64,
485    pub windows_count: i64,
486    pub unknown_count: i64,
487}
488
489/// The parameters to create a new user.
490#[derive(Debug, Serialize, Deserialize)]
491pub struct NewUserParams {
492    pub github_login: String,
493    pub github_user_id: i32,
494}
495
496/// The result of creating a new user.
497#[derive(Debug)]
498pub struct NewUserResult {
499    pub user_id: UserId,
500    pub metrics_id: String,
501    pub inviting_user_id: Option<UserId>,
502    pub signup_device_id: Option<String>,
503}
504
505/// The result of moving a channel.
506#[derive(Debug)]
507pub struct MoveChannelResult {
508    pub previous_participants: Vec<ChannelMember>,
509    pub descendent_ids: Vec<ChannelId>,
510}
511
512/// The result of renaming a channel.
513#[derive(Debug)]
514pub struct RenameChannelResult {
515    pub channel: Channel,
516    pub participants_to_update: HashMap<UserId, Channel>,
517}
518
519/// The result of creating a channel.
520#[derive(Debug)]
521pub struct CreateChannelResult {
522    pub channel: Channel,
523    pub participants_to_update: Vec<(UserId, ChannelsForUser)>,
524}
525
526/// The result of setting a channel's visibility.
527#[derive(Debug)]
528pub struct SetChannelVisibilityResult {
529    pub participants_to_update: HashMap<UserId, ChannelsForUser>,
530    pub participants_to_remove: HashSet<UserId>,
531    pub channels_to_remove: Vec<ChannelId>,
532}
533
534/// The result of updating a channel membership.
535#[derive(Debug)]
536pub struct MembershipUpdated {
537    pub channel_id: ChannelId,
538    pub new_channels: ChannelsForUser,
539    pub removed_channels: Vec<ChannelId>,
540}
541
542/// The result of setting a member's role.
543#[derive(Debug)]
544pub enum SetMemberRoleResult {
545    InviteUpdated(Channel),
546    MembershipUpdated(MembershipUpdated),
547}
548
549/// The result of inviting a member to a channel.
550#[derive(Debug)]
551pub struct InviteMemberResult {
552    pub channel: Channel,
553    pub notifications: NotificationBatch,
554}
555
556#[derive(Debug)]
557pub struct RespondToChannelInvite {
558    pub membership_update: Option<MembershipUpdated>,
559    pub notifications: NotificationBatch,
560}
561
562#[derive(Debug)]
563pub struct RemoveChannelMemberResult {
564    pub membership_update: MembershipUpdated,
565    pub notification_id: Option<NotificationId>,
566}
567
568#[derive(Debug, PartialEq, Eq, Hash)]
569pub struct Channel {
570    pub id: ChannelId,
571    pub name: String,
572    pub visibility: ChannelVisibility,
573    pub role: ChannelRole,
574    /// parent_path is the channel ids from the root to this one (not including this one)
575    pub parent_path: Vec<ChannelId>,
576}
577
578impl Channel {
579    fn from_model(value: channel::Model, role: ChannelRole) -> Self {
580        Channel {
581            id: value.id,
582            visibility: value.visibility,
583            name: value.clone().name,
584            role,
585            parent_path: value.ancestors().collect(),
586        }
587    }
588
589    pub fn to_proto(&self) -> proto::Channel {
590        proto::Channel {
591            id: self.id.to_proto(),
592            name: self.name.clone(),
593            visibility: self.visibility.into(),
594            role: self.role.into(),
595            parent_path: self.parent_path.iter().map(|c| c.to_proto()).collect(),
596        }
597    }
598}
599
600#[derive(Debug, PartialEq, Eq, Hash)]
601pub struct ChannelMember {
602    pub role: ChannelRole,
603    pub user_id: UserId,
604    pub kind: proto::channel_member::Kind,
605}
606
607impl ChannelMember {
608    pub fn to_proto(&self) -> proto::ChannelMember {
609        proto::ChannelMember {
610            role: self.role.into(),
611            user_id: self.user_id.to_proto(),
612            kind: self.kind.into(),
613        }
614    }
615}
616
617#[derive(Debug, PartialEq)]
618pub struct ChannelsForUser {
619    pub channels: Vec<Channel>,
620    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
621    pub unseen_buffer_changes: Vec<proto::UnseenChannelBufferChange>,
622    pub channel_messages: Vec<proto::UnseenChannelMessage>,
623}
624
625#[derive(Debug)]
626pub struct RejoinedChannelBuffer {
627    pub buffer: proto::RejoinedChannelBuffer,
628    pub old_connection_id: ConnectionId,
629}
630
631#[derive(Clone)]
632pub struct JoinRoom {
633    pub room: proto::Room,
634    pub channel_id: Option<ChannelId>,
635    pub channel_members: Vec<UserId>,
636}
637
638pub struct RejoinedRoom {
639    pub room: proto::Room,
640    pub rejoined_projects: Vec<RejoinedProject>,
641    pub reshared_projects: Vec<ResharedProject>,
642    pub channel_id: Option<ChannelId>,
643    pub channel_members: Vec<UserId>,
644}
645
646pub struct ResharedProject {
647    pub id: ProjectId,
648    pub old_connection_id: ConnectionId,
649    pub collaborators: Vec<ProjectCollaborator>,
650    pub worktrees: Vec<proto::WorktreeMetadata>,
651}
652
653pub struct RejoinedProject {
654    pub id: ProjectId,
655    pub old_connection_id: ConnectionId,
656    pub collaborators: Vec<ProjectCollaborator>,
657    pub worktrees: Vec<RejoinedWorktree>,
658    pub language_servers: Vec<proto::LanguageServer>,
659}
660
661#[derive(Debug)]
662pub struct RejoinedWorktree {
663    pub id: u64,
664    pub abs_path: String,
665    pub root_name: String,
666    pub visible: bool,
667    pub updated_entries: Vec<proto::Entry>,
668    pub removed_entries: Vec<u64>,
669    pub updated_repositories: Vec<proto::RepositoryEntry>,
670    pub removed_repositories: Vec<u64>,
671    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
672    pub settings_files: Vec<WorktreeSettingsFile>,
673    pub scan_id: u64,
674    pub completed_scan_id: u64,
675}
676
677pub struct LeftRoom {
678    pub room: proto::Room,
679    pub channel_id: Option<ChannelId>,
680    pub channel_members: Vec<UserId>,
681    pub left_projects: HashMap<ProjectId, LeftProject>,
682    pub canceled_calls_to_user_ids: Vec<UserId>,
683    pub deleted: bool,
684}
685
686pub struct RefreshedRoom {
687    pub room: proto::Room,
688    pub channel_id: Option<ChannelId>,
689    pub channel_members: Vec<UserId>,
690    pub stale_participant_user_ids: Vec<UserId>,
691    pub canceled_calls_to_user_ids: Vec<UserId>,
692}
693
694pub struct RefreshedChannelBuffer {
695    pub connection_ids: Vec<ConnectionId>,
696    pub collaborators: Vec<proto::Collaborator>,
697}
698
699pub struct Project {
700    pub collaborators: Vec<ProjectCollaborator>,
701    pub worktrees: BTreeMap<u64, Worktree>,
702    pub language_servers: Vec<proto::LanguageServer>,
703}
704
705pub struct ProjectCollaborator {
706    pub connection_id: ConnectionId,
707    pub user_id: UserId,
708    pub replica_id: ReplicaId,
709    pub is_host: bool,
710}
711
712impl ProjectCollaborator {
713    pub fn to_proto(&self) -> proto::Collaborator {
714        proto::Collaborator {
715            peer_id: Some(self.connection_id.into()),
716            replica_id: self.replica_id.0 as u32,
717            user_id: self.user_id.to_proto(),
718        }
719    }
720}
721
722#[derive(Debug)]
723pub struct LeftProject {
724    pub id: ProjectId,
725    pub host_user_id: UserId,
726    pub host_connection_id: ConnectionId,
727    pub connection_ids: Vec<ConnectionId>,
728}
729
730pub struct Worktree {
731    pub id: u64,
732    pub abs_path: String,
733    pub root_name: String,
734    pub visible: bool,
735    pub entries: Vec<proto::Entry>,
736    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
737    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
738    pub settings_files: Vec<WorktreeSettingsFile>,
739    pub scan_id: u64,
740    pub completed_scan_id: u64,
741}
742
743#[derive(Debug)]
744pub struct WorktreeSettingsFile {
745    pub path: String,
746    pub content: String,
747}