db.rs

  1#[cfg(test)]
  2pub mod tests;
  3
  4#[cfg(test)]
  5pub use tests::TestDb;
  6
  7mod ids;
  8mod queries;
  9mod tables;
 10
 11use crate::{executor::Executor, Error, Result};
 12use anyhow::anyhow;
 13use collections::{BTreeMap, HashMap, HashSet};
 14use dashmap::DashMap;
 15use futures::StreamExt;
 16use rand::{prelude::StdRng, Rng, SeedableRng};
 17use rpc::{
 18    proto::{self},
 19    ConnectionId,
 20};
 21use sea_orm::{
 22    entity::prelude::*,
 23    sea_query::{Alias, Expr, OnConflict},
 24    ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
 25    FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
 26    TransactionTrait,
 27};
 28use serde::{Deserialize, Serialize};
 29use sqlx::{
 30    migrate::{Migrate, Migration, MigrationSource},
 31    Connection,
 32};
 33use std::{
 34    fmt::Write as _,
 35    future::Future,
 36    marker::PhantomData,
 37    ops::{Deref, DerefMut},
 38    path::Path,
 39    rc::Rc,
 40    sync::Arc,
 41    time::Duration,
 42};
 43use tables::*;
 44use tokio::sync::{Mutex, OwnedMutexGuard};
 45
 46pub use ids::*;
 47pub use sea_orm::ConnectOptions;
 48pub use tables::user::Model as User;
 49
/// Database gives you a handle that lets you access the database.
/// It handles pooling internally.
pub struct Database {
    // Connection options; kept so `migrate` can open its own raw `sqlx::AnyConnection`
    // to the same URL.
    options: ConnectOptions,
    // sea-orm connection that transactions are begun from.
    pool: DatabaseConnection,
    // One mutex per room, created on demand. Room transactions lock it, and the lock is
    // held for as long as the returned `RoomGuard` lives.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    // Source of jitter for serialization-retry delays (seeded with 0 in `new`).
    rng: Mutex<StdRng>,
    // Used to sleep between retries and, in tests, to inject simulated delays.
    executor: Executor,
    // Notification-kind lookup tables, in both directions. They start empty in `new`;
    // presumably filled by `initialize_notification_kinds` (defined in `queries`).
    notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
    notification_kinds_by_name: HashMap<String, NotificationKindId>,
    // Test-only runtime that `run` blocks on; `None` until a test installs one.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
 63
 64// The `Database` type has so many methods that its impl blocks are split into
 65// separate files in the `queries` folder.
 66impl Database {
 67    /// Connects to the database with the given options
 68    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
 69        sqlx::any::install_default_drivers();
 70        Ok(Self {
 71            options: options.clone(),
 72            pool: sea_orm::Database::connect(options).await?,
 73            rooms: DashMap::with_capacity(16384),
 74            rng: Mutex::new(StdRng::seed_from_u64(0)),
 75            notification_kinds_by_id: HashMap::default(),
 76            notification_kinds_by_name: HashMap::default(),
 77            executor,
 78            #[cfg(test)]
 79            runtime: None,
 80        })
 81    }
 82
 83    #[cfg(test)]
 84    pub fn reset(&self) {
 85        self.rooms.clear();
 86    }
 87
 88    /// Runs the database migrations.
 89    pub async fn migrate(
 90        &self,
 91        migrations_path: &Path,
 92        ignore_checksum_mismatch: bool,
 93    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 94        let migrations = MigrationSource::resolve(migrations_path)
 95            .await
 96            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 97
 98        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
 99
100        connection.ensure_migrations_table().await?;
101        let applied_migrations: HashMap<_, _> = connection
102            .list_applied_migrations()
103            .await?
104            .into_iter()
105            .map(|m| (m.version, m))
106            .collect();
107
108        let mut new_migrations = Vec::new();
109        for migration in migrations {
110            match applied_migrations.get(&migration.version) {
111                Some(applied_migration) => {
112                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
113                    {
114                        Err(anyhow!(
115                            "checksum mismatch for applied migration {}",
116                            migration.description
117                        ))?;
118                    }
119                }
120                None => {
121                    let elapsed = connection.apply(&migration).await?;
122                    new_migrations.push((migration, elapsed));
123                }
124            }
125        }
126
127        Ok(new_migrations)
128    }
129
130    /// Initializes static data that resides in the database by upserting it.
131    pub async fn initialize_static_data(&mut self) -> Result<()> {
132        self.initialize_notification_kinds().await?;
133        Ok(())
134    }
135
136    /// Transaction runs things in a transaction. If you want to call other methods
137    /// and pass the transaction around you need to reborrow the transaction at each
138    /// call site with: `&*tx`.
139    pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
140    where
141        F: Send + Fn(TransactionHandle) -> Fut,
142        Fut: Send + Future<Output = Result<T>>,
143    {
144        let body = async {
145            let mut i = 0;
146            loop {
147                let (tx, result) = self.with_transaction(&f).await?;
148                match result {
149                    Ok(result) => match tx.commit().await.map_err(Into::into) {
150                        Ok(()) => return Ok(result),
151                        Err(error) => {
152                            if !self.retry_on_serialization_error(&error, i).await {
153                                return Err(error);
154                            }
155                        }
156                    },
157                    Err(error) => {
158                        tx.rollback().await?;
159                        if !self.retry_on_serialization_error(&error, i).await {
160                            return Err(error);
161                        }
162                    }
163                }
164                i += 1;
165            }
166        };
167
168        self.run(body).await
169    }
170
171    /// The same as room_transaction, but if you need to only optionally return a Room.
172    async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
173    where
174        F: Send + Fn(TransactionHandle) -> Fut,
175        Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
176    {
177        let body = async {
178            let mut i = 0;
179            loop {
180                let (tx, result) = self.with_transaction(&f).await?;
181                match result {
182                    Ok(Some((room_id, data))) => {
183                        let lock = self.rooms.entry(room_id).or_default().clone();
184                        let _guard = lock.lock_owned().await;
185                        match tx.commit().await.map_err(Into::into) {
186                            Ok(()) => {
187                                return Ok(Some(RoomGuard {
188                                    data,
189                                    _guard,
190                                    _not_send: PhantomData,
191                                }));
192                            }
193                            Err(error) => {
194                                if !self.retry_on_serialization_error(&error, i).await {
195                                    return Err(error);
196                                }
197                            }
198                        }
199                    }
200                    Ok(None) => match tx.commit().await.map_err(Into::into) {
201                        Ok(()) => return Ok(None),
202                        Err(error) => {
203                            if !self.retry_on_serialization_error(&error, i).await {
204                                return Err(error);
205                            }
206                        }
207                    },
208                    Err(error) => {
209                        tx.rollback().await?;
210                        if !self.retry_on_serialization_error(&error, i).await {
211                            return Err(error);
212                        }
213                    }
214                }
215                i += 1;
216            }
217        };
218
219        self.run(body).await
220    }
221
222    /// room_transaction runs the block in a transaction. It returns a RoomGuard, that keeps
223    /// the database locked until it is dropped. This ensures that updates sent to clients are
224    /// properly serialized with respect to database changes.
225    async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
226    where
227        F: Send + Fn(TransactionHandle) -> Fut,
228        Fut: Send + Future<Output = Result<T>>,
229    {
230        let body = async {
231            let mut i = 0;
232            loop {
233                let lock = self.rooms.entry(room_id).or_default().clone();
234                let _guard = lock.lock_owned().await;
235                let (tx, result) = self.with_transaction(&f).await?;
236                match result {
237                    Ok(data) => match tx.commit().await.map_err(Into::into) {
238                        Ok(()) => {
239                            return Ok(RoomGuard {
240                                data,
241                                _guard,
242                                _not_send: PhantomData,
243                            });
244                        }
245                        Err(error) => {
246                            if !self.retry_on_serialization_error(&error, i).await {
247                                return Err(error);
248                            }
249                        }
250                    },
251                    Err(error) => {
252                        tx.rollback().await?;
253                        if !self.retry_on_serialization_error(&error, i).await {
254                            return Err(error);
255                        }
256                    }
257                }
258                i += 1;
259            }
260        };
261
262        self.run(body).await
263    }
264
265    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
266    where
267        F: Send + Fn(TransactionHandle) -> Fut,
268        Fut: Send + Future<Output = Result<T>>,
269    {
270        let tx = self
271            .pool
272            .begin_with_config(Some(IsolationLevel::Serializable), None)
273            .await?;
274
275        let mut tx = Arc::new(Some(tx));
276        let result = f(TransactionHandle(tx.clone())).await;
277        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
278            return Err(anyhow!(
279                "couldn't complete transaction because it's still in use"
280            ))?;
281        };
282
283        Ok((tx, result))
284    }
285
286    async fn run<F, T>(&self, future: F) -> Result<T>
287    where
288        F: Future<Output = Result<T>>,
289    {
290        #[cfg(test)]
291        {
292            if let Executor::Deterministic(executor) = &self.executor {
293                executor.simulate_random_delay().await;
294            }
295
296            self.runtime.as_ref().unwrap().block_on(future)
297        }
298
299        #[cfg(not(test))]
300        {
301            future.await
302        }
303    }
304
305    async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
306        // If the error is due to a failure to serialize concurrent transactions, then retry
307        // this transaction after a delay. With each subsequent retry, double the delay duration.
308        // Also vary the delay randomly in order to ensure different database connections retry
309        // at different times.
310        if is_serialization_error(error) {
311            let base_delay = 4_u64 << prev_attempt_count.min(16);
312            let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
313            log::info!(
314                "retrying transaction after serialization error. delay: {} ms.",
315                randomized_delay
316            );
317            self.executor
318                .sleep(Duration::from_millis(randomized_delay as u64))
319                .await;
320            true
321        } else {
322            false
323        }
324    }
325}
326
327fn is_serialization_error(error: &Error) -> bool {
328    const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
329    match error {
330        Error::Database(
331            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
332            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
333        ) if error
334            .as_database_error()
335            .and_then(|error| error.code())
336            .as_deref()
337            == Some(SERIALIZATION_FAILURE_CODE) =>
338        {
339            true
340        }
341        _ => false,
342    }
343}
344
345/// A handle to a [`DatabaseTransaction`].
346pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
347
348impl Deref for TransactionHandle {
349    type Target = DatabaseTransaction;
350
351    fn deref(&self) -> &Self::Target {
352        self.0.as_ref().as_ref().unwrap()
353    }
354}
355
356/// [`RoomGuard`] keeps a database transaction alive until it is dropped.
357/// so that updates to rooms are serialized.
358pub struct RoomGuard<T> {
359    data: T,
360    _guard: OwnedMutexGuard<()>,
361    _not_send: PhantomData<Rc<()>>,
362}
363
364impl<T> Deref for RoomGuard<T> {
365    type Target = T;
366
367    fn deref(&self) -> &T {
368        &self.data
369    }
370}
371
372impl<T> DerefMut for RoomGuard<T> {
373    fn deref_mut(&mut self) -> &mut T {
374        &mut self.data
375    }
376}
377
378impl<T> RoomGuard<T> {
379    /// Returns the inner value of the guard.
380    pub fn into_inner(self) -> T {
381        self.data
382    }
383}
384
385#[derive(Clone, Debug, PartialEq, Eq)]
386pub enum Contact {
387    Accepted { user_id: UserId, busy: bool },
388    Outgoing { user_id: UserId },
389    Incoming { user_id: UserId },
390}
391
392impl Contact {
393    pub fn user_id(&self) -> UserId {
394        match self {
395            Contact::Accepted { user_id, .. } => *user_id,
396            Contact::Outgoing { user_id } => *user_id,
397            Contact::Incoming { user_id, .. } => *user_id,
398        }
399    }
400}
401
/// A batch of notifications, each paired with a [`UserId`].
/// NOTE(review): presumably the id is the notification's recipient; confirm against the
/// notification queries.
pub type NotificationBatch = Vec<(UserId, proto::Notification)>;

/// Data produced by creating a channel message.
///
/// NOTE(review): the field descriptions are inferred from the field names; confirm them
/// against the query that builds this value.
pub struct CreatedChannelMessage {
    /// Id of the new message.
    pub message_id: MessageId,
    /// Connection ids of the channel's participants (presumably those to broadcast to).
    pub participant_connection_ids: Vec<ConnectionId>,
    /// User ids of the channel's members.
    pub channel_members: Vec<UserId>,
    /// Notifications generated alongside the message.
    pub notifications: NotificationBatch,
}

/// An email address paired with its confirmation code.
///
/// Can be decoded directly from a query row (`FromQueryResult`) and (de)serialized with serde.
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}

/// Input describing a new signup.
///
/// Only `Deserialize` is derived, so this is presumably parsed from incoming request data
/// (TODO confirm). `device_id` and `created_at` are optional.
#[derive(Clone, Debug, Deserialize)]
pub struct NewSignup {
    pub email_address: String,
    // Which platform(s) the signup selected (presumably; inferred from names).
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
    pub created_at: Option<DateTime>,
}

/// Aggregate counts over signups, decoded directly from a query row.
///
/// NOTE(review): presumably a total plus one count per platform and an "unknown platform"
/// bucket; confirm against the query that selects these columns.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
pub struct WaitlistSummary {
    pub count: i64,
    pub linux_count: i64,
    pub mac_count: i64,
    pub windows_count: i64,
    pub unknown_count: i64,
}
438
/// The parameters to create a new user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    /// The user's GitHub login.
    pub github_login: String,
    /// The user's numeric GitHub id.
    pub github_user_id: i32,
}

/// The result of creating a new user.
///
/// NOTE(review): the field descriptions are inferred from the field names; confirm them
/// against the user-creation query.
#[derive(Debug)]
pub struct NewUserResult {
    /// Id of the created user.
    pub user_id: UserId,
    /// The user's metrics id, as a string.
    pub metrics_id: String,
    /// The user who invited the new user, if any.
    pub inviting_user_id: Option<UserId>,
    /// Device id recorded with the user's signup, if any.
    pub signup_device_id: Option<String>,
}
454
/// The result of moving a channel.
///
/// NOTE(review): the field descriptions in this group of result types are inferred from the
/// field names; confirm them against the channel queries that construct each value.
#[derive(Debug)]
pub struct MoveChannelResult {
    /// Users whose channel data must be refreshed, with the data to send each of them.
    pub participants_to_update: HashMap<UserId, ChannelsForUser>,
    /// Users who should be removed (presumably from the moved channels' views).
    pub participants_to_remove: HashSet<UserId>,
    /// Ids of the channels that moved (presumably the channel and its descendants).
    pub moved_channels: HashSet<ChannelId>,
}

/// The result of renaming a channel.
#[derive(Debug)]
pub struct RenameChannelResult {
    /// The renamed channel.
    pub channel: Channel,
    /// Per-user copies of the renamed channel (presumably each with that user's role).
    pub participants_to_update: HashMap<UserId, Channel>,
}

/// The result of creating a channel.
#[derive(Debug)]
pub struct CreateChannelResult {
    /// The new channel.
    pub channel: Channel,
    /// Users whose channel data must be refreshed, with the data to send each of them.
    pub participants_to_update: Vec<(UserId, ChannelsForUser)>,
}

/// The result of setting a channel's visibility.
#[derive(Debug)]
pub struct SetChannelVisibilityResult {
    /// Users whose channel data must be refreshed, with the data to send each of them.
    pub participants_to_update: HashMap<UserId, ChannelsForUser>,
    /// Users who should be removed (presumably because they lost access).
    pub participants_to_remove: HashSet<UserId>,
    /// Ids of channels to remove (presumably from those users' views).
    pub channels_to_remove: Vec<ChannelId>,
}

/// The result of updating a channel membership.
#[derive(Debug)]
pub struct MembershipUpdated {
    /// The channel whose membership changed.
    pub channel_id: ChannelId,
    /// Channel data the member should now have.
    pub new_channels: ChannelsForUser,
    /// Ids of channels the member should no longer see.
    pub removed_channels: Vec<ChannelId>,
}

/// The result of setting a member's role.
#[derive(Debug)]
pub enum SetMemberRoleResult {
    /// The target had a pending invite, which was updated.
    InviteUpdated(Channel),
    /// The target was a member, whose membership was updated.
    MembershipUpdated(MembershipUpdated),
}

/// The result of inviting a member to a channel.
#[derive(Debug)]
pub struct InviteMemberResult {
    pub channel: Channel,
    pub notifications: NotificationBatch,
}

/// The result of responding to a channel invite.
#[derive(Debug)]
pub struct RespondToChannelInvite {
    /// Membership changes to apply, if any (presumably `None` when the invite is declined).
    pub membership_update: Option<MembershipUpdated>,
    pub notifications: NotificationBatch,
}

/// The result of removing a member from a channel.
#[derive(Debug)]
pub struct RemoveChannelMemberResult {
    pub membership_update: MembershipUpdated,
    /// Id of an affected notification, if any (presumably the now-stale invite notification).
    pub notification_id: Option<NotificationId>,
}
518
519#[derive(Debug, PartialEq, Eq, Hash)]
520pub struct Channel {
521    pub id: ChannelId,
522    pub name: String,
523    pub visibility: ChannelVisibility,
524    pub role: ChannelRole,
525    /// parent_path is the channel ids from the root to this one (not including this one)
526    pub parent_path: Vec<ChannelId>,
527}
528
529impl Channel {
530    fn from_model(value: channel::Model, role: ChannelRole) -> Self {
531        Channel {
532            id: value.id,
533            visibility: value.visibility,
534            name: value.clone().name,
535            role,
536            parent_path: value.ancestors().collect(),
537        }
538    }
539
540    pub fn to_proto(&self) -> proto::Channel {
541        proto::Channel {
542            id: self.id.to_proto(),
543            name: self.name.clone(),
544            visibility: self.visibility.into(),
545            role: self.role.into(),
546            parent_path: self.parent_path.iter().map(|c| c.to_proto()).collect(),
547        }
548    }
549}
550
551#[derive(Debug, PartialEq, Eq, Hash)]
552pub struct ChannelMember {
553    pub role: ChannelRole,
554    pub user_id: UserId,
555    pub kind: proto::channel_member::Kind,
556}
557
558impl ChannelMember {
559    pub fn to_proto(&self) -> proto::ChannelMember {
560        proto::ChannelMember {
561            role: self.role.into(),
562            user_id: self.user_id.to_proto(),
563            kind: self.kind.into(),
564        }
565    }
566}
567
/// Channel data gathered for a single user.
///
/// NOTE(review): the field descriptions are inferred from the field names; confirm them
/// against the query that builds this value.
#[derive(Debug, PartialEq)]
pub struct ChannelsForUser {
    /// The channels themselves.
    pub channels: Vec<Channel>,
    /// User ids of each channel's participants, keyed by channel id.
    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
    /// Channel buffer changes the user presumably has not seen yet.
    pub unseen_buffer_changes: Vec<proto::UnseenChannelBufferChange>,
    /// Channel messages the user presumably has not seen yet.
    pub channel_messages: Vec<proto::UnseenChannelMessage>,
}

/// A rejoined channel buffer, paired with the connection id it was previously associated with.
#[derive(Debug)]
pub struct RejoinedChannelBuffer {
    pub buffer: proto::RejoinedChannelBuffer,
    pub old_connection_id: ConnectionId,
}
581
/// The outcome of joining a room: the room's state, plus the channel it belongs to (if any)
/// and that channel's members.
#[derive(Clone)]
pub struct JoinRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}

/// The outcome of rejoining a room (presumably after a reconnect; TODO confirm).
///
/// It lists the projects that were rejoined and those that were re-shared.
pub struct RejoinedRoom {
    pub room: proto::Room,
    pub rejoined_projects: Vec<RejoinedProject>,
    pub reshared_projects: Vec<ResharedProject>,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}

/// A project that was re-shared on rejoin (presumably by its host), with the connection id
/// it was previously shared from.
pub struct ResharedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<proto::WorktreeMetadata>,
}

/// A project that was rejoined, with the connection id previously used for it and the
/// worktree and language-server state to send back.
pub struct RejoinedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<RejoinedWorktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}

/// A worktree's state for a rejoined project: its metadata, plus updated and removed
/// entries and repositories.
///
/// NOTE(review): presumably the updated/removed lists are relative to what the rejoining
/// connection last saw (tracked via `scan_id`); confirm against the rejoin query.
#[derive(Debug)]
pub struct RejoinedWorktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub updated_entries: Vec<proto::Entry>,
    pub removed_entries: Vec<u64>,
    pub updated_repositories: Vec<proto::RepositoryEntry>,
    pub removed_repositories: Vec<u64>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
627
/// The outcome of leaving a room.
///
/// NOTE(review): the field descriptions are inferred from the field names; confirm them
/// against the leave-room query.
pub struct LeftRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    /// Projects the leaving connection was part of, keyed by project id.
    pub left_projects: HashMap<ProjectId, LeftProject>,
    /// Users whose pending calls were canceled as a result.
    pub canceled_calls_to_user_ids: Vec<UserId>,
    /// Whether the room itself was deleted.
    pub deleted: bool,
}

/// The outcome of refreshing a room (presumably pruning stale participants; TODO confirm).
pub struct RefreshedRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub stale_participant_user_ids: Vec<UserId>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}

/// A channel buffer's current connections and collaborators after a refresh.
pub struct RefreshedChannelBuffer {
    pub connection_ids: Vec<ConnectionId>,
    pub collaborators: Vec<proto::Collaborator>,
}

/// A project's collaborators, worktrees (keyed by worktree id) and language servers.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
655
656pub struct ProjectCollaborator {
657    pub connection_id: ConnectionId,
658    pub user_id: UserId,
659    pub replica_id: ReplicaId,
660    pub is_host: bool,
661}
662
663impl ProjectCollaborator {
664    pub fn to_proto(&self) -> proto::Collaborator {
665        proto::Collaborator {
666            peer_id: Some(self.connection_id.into()),
667            replica_id: self.replica_id.0 as u32,
668            user_id: self.user_id.to_proto(),
669        }
670    }
671}
672
/// A project a connection left: the project, its host, and the remaining connection ids.
///
/// NOTE(review): presumably `connection_ids` are the connections that should be told about
/// the departure; confirm against the leave-project query.
#[derive(Debug)]
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    pub connection_ids: Vec<ConnectionId>,
}

/// A project worktree's full state: metadata, entries, repositories (keyed by id),
/// diagnostics, settings files, and scan ids.
pub struct Worktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}

/// A settings file within a worktree: its path (as a string) and its contents.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    pub path: String,
    pub content: String,
}