db.rs

  1#[cfg(test)]
  2pub mod tests;
  3
  4#[cfg(test)]
  5pub use tests::TestDb;
  6
  7mod ids;
  8mod queries;
  9mod tables;
 10
 11use crate::{executor::Executor, Error, Result};
 12use anyhow::anyhow;
 13use collections::{BTreeMap, HashMap, HashSet};
 14use dashmap::DashMap;
 15use futures::StreamExt;
 16use rand::{prelude::StdRng, Rng, SeedableRng};
 17use rpc::{
 18    proto::{self},
 19    ConnectionId,
 20};
 21use sea_orm::{
 22    entity::prelude::*,
 23    sea_query::{Alias, Expr, OnConflict},
 24    ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
 25    FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
 26    TransactionTrait,
 27};
 28use serde::{Deserialize, Serialize};
 29use sqlx::{
 30    migrate::{Migrate, Migration, MigrationSource},
 31    Connection,
 32};
 33use std::{
 34    fmt::Write as _,
 35    future::Future,
 36    marker::PhantomData,
 37    ops::{Deref, DerefMut},
 38    path::Path,
 39    rc::Rc,
 40    sync::Arc,
 41    time::Duration,
 42};
 43pub use tables::*;
 44use tokio::sync::{Mutex, OwnedMutexGuard};
 45
 46pub use ids::*;
 47pub use queries::contributors::ContributorSelector;
 48pub use sea_orm::ConnectOptions;
 49pub use tables::user::Model as User;
 50
/// Database gives you a handle that lets you access the database.
/// It handles pooling internally.
pub struct Database {
    /// The options the pool was created with; `migrate` reuses their URL to
    /// open its own sqlx connection.
    options: ConnectOptions,
    /// The sea-orm connection every transaction is started from.
    pool: DatabaseConnection,
    /// One async mutex per room, created on demand by the room transaction
    /// helpers so that updates to the same room are serialized.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// RNG used to jitter the back-off delay between transaction retries.
    rng: Mutex<StdRng>,
    /// Executor used for retry sleeps (and, in tests, simulated delays).
    executor: Executor,
    // NOTE(review): both notification-kind maps start empty in `new`;
    // presumably `initialize_notification_kinds` (not visible here) fills
    // them — confirm.
    notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
    notification_kinds_by_name: HashMap<String, NotificationKindId>,
    /// Dedicated runtime that `run` blocks on in test builds.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
 64
 65// The `Database` type has so many methods that its impl blocks are split into
 66// separate files in the `queries` folder.
 67impl Database {
 68    /// Connects to the database with the given options
 69    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
 70        sqlx::any::install_default_drivers();
 71        Ok(Self {
 72            options: options.clone(),
 73            pool: sea_orm::Database::connect(options).await?,
 74            rooms: DashMap::with_capacity(16384),
 75            rng: Mutex::new(StdRng::seed_from_u64(0)),
 76            notification_kinds_by_id: HashMap::default(),
 77            notification_kinds_by_name: HashMap::default(),
 78            executor,
 79            #[cfg(test)]
 80            runtime: None,
 81        })
 82    }
 83
 84    #[cfg(test)]
 85    pub fn reset(&self) {
 86        self.rooms.clear();
 87    }
 88
 89    /// Runs the database migrations.
 90    pub async fn migrate(
 91        &self,
 92        migrations_path: &Path,
 93        ignore_checksum_mismatch: bool,
 94    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 95        let migrations = MigrationSource::resolve(migrations_path)
 96            .await
 97            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 98
 99        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
100
101        connection.ensure_migrations_table().await?;
102        let applied_migrations: HashMap<_, _> = connection
103            .list_applied_migrations()
104            .await?
105            .into_iter()
106            .map(|m| (m.version, m))
107            .collect();
108
109        let mut new_migrations = Vec::new();
110        for migration in migrations {
111            match applied_migrations.get(&migration.version) {
112                Some(applied_migration) => {
113                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
114                    {
115                        Err(anyhow!(
116                            "checksum mismatch for applied migration {}",
117                            migration.description
118                        ))?;
119                    }
120                }
121                None => {
122                    let elapsed = connection.apply(&migration).await?;
123                    new_migrations.push((migration, elapsed));
124                }
125            }
126        }
127
128        Ok(new_migrations)
129    }
130
131    /// Initializes static data that resides in the database by upserting it.
132    pub async fn initialize_static_data(&mut self) -> Result<()> {
133        self.initialize_notification_kinds().await?;
134        Ok(())
135    }
136
137    /// Transaction runs things in a transaction. If you want to call other methods
138    /// and pass the transaction around you need to reborrow the transaction at each
139    /// call site with: `&*tx`.
140    pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
141    where
142        F: Send + Fn(TransactionHandle) -> Fut,
143        Fut: Send + Future<Output = Result<T>>,
144    {
145        let body = async {
146            let mut i = 0;
147            loop {
148                let (tx, result) = self.with_transaction(&f).await?;
149                match result {
150                    Ok(result) => match tx.commit().await.map_err(Into::into) {
151                        Ok(()) => return Ok(result),
152                        Err(error) => {
153                            if !self.retry_on_serialization_error(&error, i).await {
154                                return Err(error);
155                            }
156                        }
157                    },
158                    Err(error) => {
159                        tx.rollback().await?;
160                        if !self.retry_on_serialization_error(&error, i).await {
161                            return Err(error);
162                        }
163                    }
164                }
165                i += 1;
166            }
167        };
168
169        self.run(body).await
170    }
171
172    pub async fn weak_transaction<F, Fut, T>(&self, f: F) -> Result<T>
173    where
174        F: Send + Fn(TransactionHandle) -> Fut,
175        Fut: Send + Future<Output = Result<T>>,
176    {
177        let body = async {
178            let (tx, result) = self.with_weak_transaction(&f).await?;
179            match result {
180                Ok(result) => match tx.commit().await.map_err(Into::into) {
181                    Ok(()) => return Ok(result),
182                    Err(error) => {
183                        return Err(error);
184                    }
185                },
186                Err(error) => {
187                    tx.rollback().await?;
188                    return Err(error);
189                }
190            }
191        };
192
193        self.run(body).await
194    }
195
196    /// The same as room_transaction, but if you need to only optionally return a Room.
197    async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
198    where
199        F: Send + Fn(TransactionHandle) -> Fut,
200        Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
201    {
202        let body = async {
203            let mut i = 0;
204            loop {
205                let (tx, result) = self.with_transaction(&f).await?;
206                match result {
207                    Ok(Some((room_id, data))) => {
208                        let lock = self.rooms.entry(room_id).or_default().clone();
209                        let _guard = lock.lock_owned().await;
210                        match tx.commit().await.map_err(Into::into) {
211                            Ok(()) => {
212                                return Ok(Some(RoomGuard {
213                                    data,
214                                    _guard,
215                                    _not_send: PhantomData,
216                                }));
217                            }
218                            Err(error) => {
219                                if !self.retry_on_serialization_error(&error, i).await {
220                                    return Err(error);
221                                }
222                            }
223                        }
224                    }
225                    Ok(None) => match tx.commit().await.map_err(Into::into) {
226                        Ok(()) => return Ok(None),
227                        Err(error) => {
228                            if !self.retry_on_serialization_error(&error, i).await {
229                                return Err(error);
230                            }
231                        }
232                    },
233                    Err(error) => {
234                        tx.rollback().await?;
235                        if !self.retry_on_serialization_error(&error, i).await {
236                            return Err(error);
237                        }
238                    }
239                }
240                i += 1;
241            }
242        };
243
244        self.run(body).await
245    }
246
247    /// room_transaction runs the block in a transaction. It returns a RoomGuard, that keeps
248    /// the database locked until it is dropped. This ensures that updates sent to clients are
249    /// properly serialized with respect to database changes.
250    async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
251    where
252        F: Send + Fn(TransactionHandle) -> Fut,
253        Fut: Send + Future<Output = Result<T>>,
254    {
255        let body = async {
256            let mut i = 0;
257            loop {
258                let lock = self.rooms.entry(room_id).or_default().clone();
259                let _guard = lock.lock_owned().await;
260                let (tx, result) = self.with_transaction(&f).await?;
261                match result {
262                    Ok(data) => match tx.commit().await.map_err(Into::into) {
263                        Ok(()) => {
264                            return Ok(RoomGuard {
265                                data,
266                                _guard,
267                                _not_send: PhantomData,
268                            });
269                        }
270                        Err(error) => {
271                            if !self.retry_on_serialization_error(&error, i).await {
272                                return Err(error);
273                            }
274                        }
275                    },
276                    Err(error) => {
277                        tx.rollback().await?;
278                        if !self.retry_on_serialization_error(&error, i).await {
279                            return Err(error);
280                        }
281                    }
282                }
283                i += 1;
284            }
285        };
286
287        self.run(body).await
288    }
289
290    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
291    where
292        F: Send + Fn(TransactionHandle) -> Fut,
293        Fut: Send + Future<Output = Result<T>>,
294    {
295        let tx = self
296            .pool
297            .begin_with_config(Some(IsolationLevel::Serializable), None)
298            .await?;
299
300        let mut tx = Arc::new(Some(tx));
301        let result = f(TransactionHandle(tx.clone())).await;
302        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
303            return Err(anyhow!(
304                "couldn't complete transaction because it's still in use"
305            ))?;
306        };
307
308        Ok((tx, result))
309    }
310
311    async fn with_weak_transaction<F, Fut, T>(
312        &self,
313        f: &F,
314    ) -> Result<(DatabaseTransaction, Result<T>)>
315    where
316        F: Send + Fn(TransactionHandle) -> Fut,
317        Fut: Send + Future<Output = Result<T>>,
318    {
319        let tx = self
320            .pool
321            .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
322            .await?;
323
324        let mut tx = Arc::new(Some(tx));
325        let result = f(TransactionHandle(tx.clone())).await;
326        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
327            return Err(anyhow!(
328                "couldn't complete transaction because it's still in use"
329            ))?;
330        };
331
332        Ok((tx, result))
333    }
334
335    async fn run<F, T>(&self, future: F) -> Result<T>
336    where
337        F: Future<Output = Result<T>>,
338    {
339        #[cfg(test)]
340        {
341            if let Executor::Deterministic(executor) = &self.executor {
342                executor.simulate_random_delay().await;
343            }
344
345            self.runtime.as_ref().unwrap().block_on(future)
346        }
347
348        #[cfg(not(test))]
349        {
350            future.await
351        }
352    }
353
354    async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: usize) -> bool {
355        // If the error is due to a failure to serialize concurrent transactions, then retry
356        // this transaction after a delay. With each subsequent retry, double the delay duration.
357        // Also vary the delay randomly in order to ensure different database connections retry
358        // at different times.
359        const SLEEPS: [f32; 10] = [10., 20., 40., 80., 160., 320., 640., 1280., 2560., 5120.];
360        if is_serialization_error(error) && prev_attempt_count < SLEEPS.len() {
361            let base_delay = SLEEPS[prev_attempt_count];
362            let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
363            log::info!(
364                "retrying transaction after serialization error. delay: {} ms.",
365                randomized_delay
366            );
367            self.executor
368                .sleep(Duration::from_millis(randomized_delay as u64))
369                .await;
370            true
371        } else {
372            false
373        }
374    }
375}
376
377fn is_serialization_error(error: &Error) -> bool {
378    const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
379    match error {
380        Error::Database(
381            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
382            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
383        ) if error
384            .as_database_error()
385            .and_then(|error| error.code())
386            .as_deref()
387            == Some(SERIALIZATION_FAILURE_CODE) =>
388        {
389            true
390        }
391        _ => false,
392    }
393}
394
395/// A handle to a [`DatabaseTransaction`].
396pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
397
398impl Deref for TransactionHandle {
399    type Target = DatabaseTransaction;
400
401    fn deref(&self) -> &Self::Target {
402        self.0.as_ref().as_ref().unwrap()
403    }
404}
405
406/// [`RoomGuard`] keeps a database transaction alive until it is dropped.
407/// so that updates to rooms are serialized.
408pub struct RoomGuard<T> {
409    data: T,
410    _guard: OwnedMutexGuard<()>,
411    _not_send: PhantomData<Rc<()>>,
412}
413
414impl<T> Deref for RoomGuard<T> {
415    type Target = T;
416
417    fn deref(&self) -> &T {
418        &self.data
419    }
420}
421
422impl<T> DerefMut for RoomGuard<T> {
423    fn deref_mut(&mut self) -> &mut T {
424        &mut self.data
425    }
426}
427
428impl<T> RoomGuard<T> {
429    /// Returns the inner value of the guard.
430    pub fn into_inner(self) -> T {
431        self.data
432    }
433}
434
435#[derive(Clone, Debug, PartialEq, Eq)]
436pub enum Contact {
437    Accepted { user_id: UserId, busy: bool },
438    Outgoing { user_id: UserId },
439    Incoming { user_id: UserId },
440}
441
442impl Contact {
443    pub fn user_id(&self) -> UserId {
444        match self {
445            Contact::Accepted { user_id, .. } => *user_id,
446            Contact::Outgoing { user_id } => *user_id,
447            Contact::Incoming { user_id, .. } => *user_id,
448        }
449    }
450}
451
/// A list of `(UserId, Notification)` pairs.
// NOTE(review): the user id is presumably the notification's recipient —
// confirm against the code that builds these batches.
pub type NotificationBatch = Vec<(UserId, proto::Notification)>;

/// The result of creating a message in a channel.
pub struct CreatedChannelMessage {
    /// The id of the newly created message.
    pub message_id: MessageId,
    pub participant_connection_ids: Vec<ConnectionId>,
    pub channel_members: Vec<UserId>,
    /// Notifications produced as a side effect of creating the message.
    pub notifications: NotificationBatch,
}
460
/// An email invite: the invited address and its confirmation code.
///
/// Derives `FromQueryResult`, so it can be read straight from query rows
/// whose columns are named like its fields.
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
466
/// Details submitted by a prospective user when signing up.
///
/// Only `Deserialize` is derived, so this is read from external input
/// (presumably a request body — confirm against callers).
#[derive(Clone, Debug, Deserialize)]
pub struct NewSignup {
    pub email_address: String,
    // Presumably the platforms the user selected interest in.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
    // NOTE(review): optional; presumably a default timestamp is used when
    // absent — confirm in the insert query.
    pub created_at: Option<DateTime>,
}
479
/// Aggregate counts over the signup waitlist.
///
/// Derives `FromQueryResult`, so the producing query must alias its columns
/// to these field names.
// NOTE(review): presumably `count` is the total and the other fields split
// it by platform — confirm against the query.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
pub struct WaitlistSummary {
    pub count: i64,
    pub linux_count: i64,
    pub mac_count: i64,
    pub windows_count: i64,
    pub unknown_count: i64,
}
488
/// The parameters to create a new user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    /// The user's GitHub login name.
    pub github_login: String,
    /// The user's numeric GitHub id.
    pub github_user_id: i32,
}

/// The result of creating a new user.
#[derive(Debug)]
pub struct NewUserResult {
    /// The id of the newly created user.
    pub user_id: UserId,
    pub metrics_id: String,
    // Optional: presumably set only when the signup came from an invite /
    // a recorded device — confirm against the creating query.
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
504
/// The result of updating a channel membership.
#[derive(Debug)]
pub struct MembershipUpdated {
    /// The channel whose membership changed.
    pub channel_id: ChannelId,
    /// Channel data that became visible to the user.
    pub new_channels: ChannelsForUser,
    /// Channels the user can no longer see.
    pub removed_channels: Vec<ChannelId>,
}

/// The result of setting a member's role.
///
/// Depending on the member's state, either a pending invite was updated or an
/// actual membership was.
#[derive(Debug)]
pub enum SetMemberRoleResult {
    InviteUpdated(Channel),
    MembershipUpdated(MembershipUpdated),
}

/// The result of inviting a member to a channel.
#[derive(Debug)]
pub struct InviteMemberResult {
    pub channel: Channel,
    pub notifications: NotificationBatch,
}

/// The result of responding to a channel invite.
#[derive(Debug)]
pub struct RespondToChannelInvite {
    // `None` presumably when the invite was declined and no membership
    // changed — confirm against the responding query.
    pub membership_update: Option<MembershipUpdated>,
    pub notifications: NotificationBatch,
}

/// The result of removing a member from a channel.
#[derive(Debug)]
pub struct RemoveChannelMemberResult {
    pub membership_update: MembershipUpdated,
    pub notification_id: Option<NotificationId>,
}
538
539#[derive(Debug, PartialEq, Eq, Hash)]
540pub struct Channel {
541    pub id: ChannelId,
542    pub name: String,
543    pub visibility: ChannelVisibility,
544    /// parent_path is the channel ids from the root to this one (not including this one)
545    pub parent_path: Vec<ChannelId>,
546}
547
548impl Channel {
549    fn from_model(value: channel::Model) -> Self {
550        Channel {
551            id: value.id,
552            visibility: value.visibility,
553            name: value.clone().name,
554            parent_path: value.ancestors().collect(),
555        }
556    }
557
558    pub fn to_proto(&self) -> proto::Channel {
559        proto::Channel {
560            id: self.id.to_proto(),
561            name: self.name.clone(),
562            visibility: self.visibility.into(),
563            parent_path: self.parent_path.iter().map(|c| c.to_proto()).collect(),
564        }
565    }
566}
567
568#[derive(Debug, PartialEq, Eq, Hash)]
569pub struct ChannelMember {
570    pub role: ChannelRole,
571    pub user_id: UserId,
572    pub kind: proto::channel_member::Kind,
573}
574
575impl ChannelMember {
576    pub fn to_proto(&self) -> proto::ChannelMember {
577        proto::ChannelMember {
578            role: self.role.into(),
579            user_id: self.user_id.to_proto(),
580            kind: self.kind.into(),
581        }
582    }
583}
584
/// The channel data sent to a user.
#[derive(Debug, PartialEq)]
pub struct ChannelsForUser {
    pub channels: Vec<Channel>,
    pub channel_memberships: Vec<channel_member::Model>,
    /// For each channel, the ids of the users participating in it.
    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
    pub latest_buffer_versions: Vec<proto::ChannelBufferVersion>,
    pub latest_channel_messages: Vec<proto::ChannelMessageId>,
}

/// A channel buffer rejoined after reconnecting, together with the
/// connection that previously had it open.
#[derive(Debug)]
pub struct RejoinedChannelBuffer {
    pub buffer: proto::RejoinedChannelBuffer,
    pub old_connection_id: ConnectionId,
}
599
/// The result of joining a room.
#[derive(Clone)]
pub struct JoinRoom {
    pub room: proto::Room,
    /// The channel the room belongs to, if it is a channel room.
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}

/// The result of rejoining a room after reconnecting.
pub struct RejoinedRoom {
    pub room: proto::Room,
    pub rejoined_projects: Vec<RejoinedProject>,
    pub reshared_projects: Vec<ResharedProject>,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}

/// A project re-shared by its host after reconnecting, keyed to the host's
/// previous connection.
pub struct ResharedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<proto::WorktreeMetadata>,
}

/// A project rejoined as a guest after reconnecting, keyed to the previous
/// connection.
pub struct RejoinedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<RejoinedWorktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}

/// A worktree's state for a rejoining client.
///
/// It lists updated and removed entries and repositories rather than the
/// full contents.
// NOTE(review): presumably the delta is relative to the scan the client last
// observed — confirm in the rejoin query.
#[derive(Debug)]
pub struct RejoinedWorktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub updated_entries: Vec<proto::Entry>,
    pub removed_entries: Vec<u64>,
    pub updated_repositories: Vec<proto::RepositoryEntry>,
    pub removed_repositories: Vec<u64>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
645
/// The result of a participant leaving a room.
pub struct LeftRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    /// Projects the participant left, keyed by project id.
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
    // Presumably `true` when leaving caused the room itself to be deleted —
    // confirm in the leave-room query.
    pub deleted: bool,
}

/// The result of refreshing a room's participants.
pub struct RefreshedRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub stale_participant_user_ids: Vec<UserId>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}

/// The connections and collaborators of a channel buffer after a refresh.
pub struct RefreshedChannelBuffer {
    pub connection_ids: Vec<ConnectionId>,
    pub collaborators: Vec<proto::Collaborator>,
}

/// A shared project: its collaborators, worktrees (keyed by worktree id),
/// and language servers.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
673
674pub struct ProjectCollaborator {
675    pub connection_id: ConnectionId,
676    pub user_id: UserId,
677    pub replica_id: ReplicaId,
678    pub is_host: bool,
679}
680
681impl ProjectCollaborator {
682    pub fn to_proto(&self) -> proto::Collaborator {
683        proto::Collaborator {
684            peer_id: Some(self.connection_id.into()),
685            replica_id: self.replica_id.0 as u32,
686            user_id: self.user_id.to_proto(),
687        }
688    }
689}
690
/// A project that a participant left, with the connections still in it.
#[derive(Debug)]
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    pub connection_ids: Vec<ConnectionId>,
}

/// A worktree's full state as stored in the database.
pub struct Worktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    /// Repository entries keyed by an id.
    // NOTE(review): presumably the repository's work-directory entry id —
    // confirm.
    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}

/// A settings file found in a worktree: its path and raw contents.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    pub path: String,
    pub content: String,
}