db.rs

  1mod ids;
  2mod queries;
  3mod tables;
  4#[cfg(test)]
  5pub mod tests;
  6
  7use crate::{executor::Executor, Error, Result};
  8use anyhow::anyhow;
  9use collections::{BTreeMap, HashMap, HashSet};
 10use dashmap::DashMap;
 11use futures::StreamExt;
 12use rand::{prelude::StdRng, Rng, SeedableRng};
 13use rpc::{
 14    proto::{self},
 15    ConnectionId, ExtensionMetadata,
 16};
 17use sea_orm::{
 18    entity::prelude::*,
 19    sea_query::{Alias, Expr, OnConflict},
 20    ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
 21    FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
 22    TransactionTrait,
 23};
 24use semantic_version::SemanticVersion;
 25use serde::{Deserialize, Serialize};
 26use sqlx::{
 27    migrate::{Migrate, Migration, MigrationSource},
 28    Connection,
 29};
 30use std::ops::RangeInclusive;
 31use std::{
 32    fmt::Write as _,
 33    future::Future,
 34    marker::PhantomData,
 35    ops::{Deref, DerefMut},
 36    path::Path,
 37    rc::Rc,
 38    sync::Arc,
 39    time::Duration,
 40};
 41use time::PrimitiveDateTime;
 42use tokio::sync::{Mutex, OwnedMutexGuard};
 43
 44#[cfg(test)]
 45pub use tests::TestDb;
 46
 47pub use ids::*;
 48pub use queries::contributors::ContributorSelector;
 49pub use sea_orm::ConnectOptions;
 50pub use tables::user::Model as User;
 51pub use tables::*;
 52
/// Database gives you a handle that lets you access the database.
/// It handles pooling internally.
pub struct Database {
    /// Options used to open `pool`; `migrate` reuses their URL to open a separate
    /// sqlx connection.
    options: ConnectOptions,
    /// sea-orm connection from which every transaction is begun.
    pool: DatabaseConnection,
    /// One async mutex per room, created on demand by the room transaction helpers
    /// so that work on the same room is serialized.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// Source of the random jitter applied to serialization-retry delays
    /// (seeded with a fixed value in `new`).
    rng: Mutex<StdRng>,
    /// Executor used to sleep between retries and, in tests, to simulate delays.
    executor: Executor,
    /// Notification kind names keyed by id. Empty after `new`;
    /// NOTE(review): presumably populated by code in `queries` — confirm.
    notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
    /// Notification kind ids keyed by name (presumably the inverse of
    /// `notification_kinds_by_id`; also empty after `new`).
    notification_kinds_by_name: HashMap<String, NotificationKindId>,
    /// Test-only runtime on which `run` drives futures with `block_on`.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
 66
 67// The `Database` type has so many methods that its impl blocks are split into
 68// separate files in the `queries` folder.
 69impl Database {
 70    /// Connects to the database with the given options
 71    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
 72        sqlx::any::install_default_drivers();
 73        Ok(Self {
 74            options: options.clone(),
 75            pool: sea_orm::Database::connect(options).await?,
 76            rooms: DashMap::with_capacity(16384),
 77            rng: Mutex::new(StdRng::seed_from_u64(0)),
 78            notification_kinds_by_id: HashMap::default(),
 79            notification_kinds_by_name: HashMap::default(),
 80            executor,
 81            #[cfg(test)]
 82            runtime: None,
 83        })
 84    }
 85
 86    #[cfg(test)]
 87    pub fn reset(&self) {
 88        self.rooms.clear();
 89    }
 90
 91    /// Runs the database migrations.
 92    pub async fn migrate(
 93        &self,
 94        migrations_path: &Path,
 95        ignore_checksum_mismatch: bool,
 96    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 97        let migrations = MigrationSource::resolve(migrations_path)
 98            .await
 99            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
100
101        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
102
103        connection.ensure_migrations_table().await?;
104        let applied_migrations: HashMap<_, _> = connection
105            .list_applied_migrations()
106            .await?
107            .into_iter()
108            .map(|m| (m.version, m))
109            .collect();
110
111        let mut new_migrations = Vec::new();
112        for migration in migrations {
113            match applied_migrations.get(&migration.version) {
114                Some(applied_migration) => {
115                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
116                    {
117                        Err(anyhow!(
118                            "checksum mismatch for applied migration {}",
119                            migration.description
120                        ))?;
121                    }
122                }
123                None => {
124                    let elapsed = connection.apply(&migration).await?;
125                    new_migrations.push((migration, elapsed));
126                }
127            }
128        }
129
130        Ok(new_migrations)
131    }
132
133    /// Transaction runs things in a transaction. If you want to call other methods
134    /// and pass the transaction around you need to reborrow the transaction at each
135    /// call site with: `&*tx`.
136    pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
137    where
138        F: Send + Fn(TransactionHandle) -> Fut,
139        Fut: Send + Future<Output = Result<T>>,
140    {
141        let body = async {
142            let mut i = 0;
143            loop {
144                let (tx, result) = self.with_transaction(&f).await?;
145                match result {
146                    Ok(result) => match tx.commit().await.map_err(Into::into) {
147                        Ok(()) => return Ok(result),
148                        Err(error) => {
149                            if !self.retry_on_serialization_error(&error, i).await {
150                                return Err(error);
151                            }
152                        }
153                    },
154                    Err(error) => {
155                        tx.rollback().await?;
156                        if !self.retry_on_serialization_error(&error, i).await {
157                            return Err(error);
158                        }
159                    }
160                }
161                i += 1;
162            }
163        };
164
165        self.run(body).await
166    }
167
168    pub async fn weak_transaction<F, Fut, T>(&self, f: F) -> Result<T>
169    where
170        F: Send + Fn(TransactionHandle) -> Fut,
171        Fut: Send + Future<Output = Result<T>>,
172    {
173        let body = async {
174            let (tx, result) = self.with_weak_transaction(&f).await?;
175            match result {
176                Ok(result) => match tx.commit().await.map_err(Into::into) {
177                    Ok(()) => return Ok(result),
178                    Err(error) => {
179                        return Err(error);
180                    }
181                },
182                Err(error) => {
183                    tx.rollback().await?;
184                    return Err(error);
185                }
186            }
187        };
188
189        self.run(body).await
190    }
191
192    /// The same as room_transaction, but if you need to only optionally return a Room.
193    async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
194    where
195        F: Send + Fn(TransactionHandle) -> Fut,
196        Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
197    {
198        let body = async {
199            let mut i = 0;
200            loop {
201                let (tx, result) = self.with_transaction(&f).await?;
202                match result {
203                    Ok(Some((room_id, data))) => {
204                        let lock = self.rooms.entry(room_id).or_default().clone();
205                        let _guard = lock.lock_owned().await;
206                        match tx.commit().await.map_err(Into::into) {
207                            Ok(()) => {
208                                return Ok(Some(RoomGuard {
209                                    data,
210                                    _guard,
211                                    _not_send: PhantomData,
212                                }));
213                            }
214                            Err(error) => {
215                                if !self.retry_on_serialization_error(&error, i).await {
216                                    return Err(error);
217                                }
218                            }
219                        }
220                    }
221                    Ok(None) => match tx.commit().await.map_err(Into::into) {
222                        Ok(()) => return Ok(None),
223                        Err(error) => {
224                            if !self.retry_on_serialization_error(&error, i).await {
225                                return Err(error);
226                            }
227                        }
228                    },
229                    Err(error) => {
230                        tx.rollback().await?;
231                        if !self.retry_on_serialization_error(&error, i).await {
232                            return Err(error);
233                        }
234                    }
235                }
236                i += 1;
237            }
238        };
239
240        self.run(body).await
241    }
242
243    /// room_transaction runs the block in a transaction. It returns a RoomGuard, that keeps
244    /// the database locked until it is dropped. This ensures that updates sent to clients are
245    /// properly serialized with respect to database changes.
246    async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
247    where
248        F: Send + Fn(TransactionHandle) -> Fut,
249        Fut: Send + Future<Output = Result<T>>,
250    {
251        let body = async {
252            let mut i = 0;
253            loop {
254                let lock = self.rooms.entry(room_id).or_default().clone();
255                let _guard = lock.lock_owned().await;
256                let (tx, result) = self.with_transaction(&f).await?;
257                match result {
258                    Ok(data) => match tx.commit().await.map_err(Into::into) {
259                        Ok(()) => {
260                            return Ok(RoomGuard {
261                                data,
262                                _guard,
263                                _not_send: PhantomData,
264                            });
265                        }
266                        Err(error) => {
267                            if !self.retry_on_serialization_error(&error, i).await {
268                                return Err(error);
269                            }
270                        }
271                    },
272                    Err(error) => {
273                        tx.rollback().await?;
274                        if !self.retry_on_serialization_error(&error, i).await {
275                            return Err(error);
276                        }
277                    }
278                }
279                i += 1;
280            }
281        };
282
283        self.run(body).await
284    }
285
286    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
287    where
288        F: Send + Fn(TransactionHandle) -> Fut,
289        Fut: Send + Future<Output = Result<T>>,
290    {
291        let tx = self
292            .pool
293            .begin_with_config(Some(IsolationLevel::Serializable), None)
294            .await?;
295
296        let mut tx = Arc::new(Some(tx));
297        let result = f(TransactionHandle(tx.clone())).await;
298        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
299            return Err(anyhow!(
300                "couldn't complete transaction because it's still in use"
301            ))?;
302        };
303
304        Ok((tx, result))
305    }
306
307    async fn with_weak_transaction<F, Fut, T>(
308        &self,
309        f: &F,
310    ) -> Result<(DatabaseTransaction, Result<T>)>
311    where
312        F: Send + Fn(TransactionHandle) -> Fut,
313        Fut: Send + Future<Output = Result<T>>,
314    {
315        let tx = self
316            .pool
317            .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
318            .await?;
319
320        let mut tx = Arc::new(Some(tx));
321        let result = f(TransactionHandle(tx.clone())).await;
322        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
323            return Err(anyhow!(
324                "couldn't complete transaction because it's still in use"
325            ))?;
326        };
327
328        Ok((tx, result))
329    }
330
331    async fn run<F, T>(&self, future: F) -> Result<T>
332    where
333        F: Future<Output = Result<T>>,
334    {
335        #[cfg(test)]
336        {
337            if let Executor::Deterministic(executor) = &self.executor {
338                executor.simulate_random_delay().await;
339            }
340
341            self.runtime.as_ref().unwrap().block_on(future)
342        }
343
344        #[cfg(not(test))]
345        {
346            future.await
347        }
348    }
349
350    async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: usize) -> bool {
351        // If the error is due to a failure to serialize concurrent transactions, then retry
352        // this transaction after a delay. With each subsequent retry, double the delay duration.
353        // Also vary the delay randomly in order to ensure different database connections retry
354        // at different times.
355        const SLEEPS: [f32; 10] = [10., 20., 40., 80., 160., 320., 640., 1280., 2560., 5120.];
356        if is_serialization_error(error) && prev_attempt_count < SLEEPS.len() {
357            let base_delay = SLEEPS[prev_attempt_count];
358            let randomized_delay = base_delay * self.rng.lock().await.gen_range(0.5..=2.0);
359            log::info!(
360                "retrying transaction after serialization error. delay: {} ms.",
361                randomized_delay
362            );
363            self.executor
364                .sleep(Duration::from_millis(randomized_delay as u64))
365                .await;
366            true
367        } else {
368            false
369        }
370    }
371}
372
373fn is_serialization_error(error: &Error) -> bool {
374    const SERIALIZATION_FAILURE_CODE: &str = "40001";
375    match error {
376        Error::Database(
377            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
378            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
379        ) if error
380            .as_database_error()
381            .and_then(|error| error.code())
382            .as_deref()
383            == Some(SERIALIZATION_FAILURE_CODE) =>
384        {
385            true
386        }
387        _ => false,
388    }
389}
390
391/// A handle to a [`DatabaseTransaction`].
392pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
393
394impl Deref for TransactionHandle {
395    type Target = DatabaseTransaction;
396
397    fn deref(&self) -> &Self::Target {
398        self.0.as_ref().as_ref().unwrap()
399    }
400}
401
402/// [`RoomGuard`] keeps a database transaction alive until it is dropped.
403/// so that updates to rooms are serialized.
404pub struct RoomGuard<T> {
405    data: T,
406    _guard: OwnedMutexGuard<()>,
407    _not_send: PhantomData<Rc<()>>,
408}
409
410impl<T> Deref for RoomGuard<T> {
411    type Target = T;
412
413    fn deref(&self) -> &T {
414        &self.data
415    }
416}
417
418impl<T> DerefMut for RoomGuard<T> {
419    fn deref_mut(&mut self) -> &mut T {
420        &mut self.data
421    }
422}
423
424impl<T> RoomGuard<T> {
425    /// Returns the inner value of the guard.
426    pub fn into_inner(self) -> T {
427        self.data
428    }
429}
430
431#[derive(Clone, Debug, PartialEq, Eq)]
432pub enum Contact {
433    Accepted { user_id: UserId, busy: bool },
434    Outgoing { user_id: UserId },
435    Incoming { user_id: UserId },
436}
437
438impl Contact {
439    pub fn user_id(&self) -> UserId {
440        match self {
441            Contact::Accepted { user_id, .. } => *user_id,
442            Contact::Outgoing { user_id } => *user_id,
443            Contact::Incoming { user_id, .. } => *user_id,
444        }
445    }
446}
447
/// Notifications, each paired with a [`UserId`] (presumably the recipient —
/// TODO confirm against the notification queries).
pub type NotificationBatch = Vec<(UserId, proto::Notification)>;

/// Data produced when a channel message is created.
pub struct CreatedChannelMessage {
    /// Id of the new message.
    pub message_id: MessageId,
    /// Connection ids of the channel's participants (NOTE(review): presumably the
    /// connections to broadcast the message to — confirm at call sites).
    pub participant_connection_ids: Vec<ConnectionId>,
    /// Ids of the channel's members.
    pub channel_members: Vec<UserId>,
    /// Notifications produced alongside the message.
    pub notifications: NotificationBatch,
}

/// Data produced when a channel message is updated.
pub struct UpdatedChannelMessage {
    /// Id of the updated message.
    pub message_id: MessageId,
    /// Connection ids of the channel's participants.
    pub participant_connection_ids: Vec<ConnectionId>,
    /// Notifications produced alongside the update.
    pub notifications: NotificationBatch,
    /// The message this one replies to, if any.
    pub reply_to_message_id: Option<MessageId>,
    /// Message timestamp. A `PrimitiveDateTime` carries no offset;
    /// NOTE(review): presumably UTC — confirm.
    pub timestamp: PrimitiveDateTime,
}
464
/// An email invite. Loadable directly from a query row (`FromQueryResult`) and
/// (de)serializable with serde.
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}

/// Signup data, deserialized from external input.
#[derive(Clone, Debug, Deserialize)]
pub struct NewSignup {
    pub email_address: String,
    // Platforms the signup selected; several may be set at once.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
    /// Optional creation timestamp (NOTE(review): behavior when absent is decided by
    /// the consuming query — confirm).
    pub created_at: Option<DateTime>,
}

/// Aggregate waitlist counts, loadable directly from a query row.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
pub struct WaitlistSummary {
    pub count: i64,
    pub linux_count: i64,
    pub mac_count: i64,
    pub windows_count: i64,
    pub unknown_count: i64,
}
492
/// The parameters to create a new user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
}

/// The result of creating a new user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}

/// The result of updating a channel membership.
#[derive(Debug)]
pub struct MembershipUpdated {
    /// The channel whose membership changed.
    pub channel_id: ChannelId,
    /// Channel data for the affected user after the change.
    pub new_channels: ChannelsForUser,
    /// Ids of channels removed for the affected user.
    pub removed_channels: Vec<ChannelId>,
}

/// The result of setting a member's role.
#[derive(Debug)]
pub enum SetMemberRoleResult {
    /// The role change applied to a pending invite; carries the channel.
    InviteUpdated(Channel),
    /// The role change applied to an existing membership.
    MembershipUpdated(MembershipUpdated),
}

/// The result of inviting a member to a channel.
#[derive(Debug)]
pub struct InviteMemberResult {
    pub channel: Channel,
    pub notifications: NotificationBatch,
}

/// The result of responding to a channel invite.
#[derive(Debug)]
pub struct RespondToChannelInvite {
    /// Present only when the response changed the user's membership
    /// (NOTE(review): presumably on acceptance — confirm).
    pub membership_update: Option<MembershipUpdated>,
    pub notifications: NotificationBatch,
}

/// The result of removing a member from a channel.
#[derive(Debug)]
pub struct RemoveChannelMemberResult {
    pub membership_update: MembershipUpdated,
    /// Id of a related notification, if one exists.
    pub notification_id: Option<NotificationId>,
}
542
543#[derive(Debug, PartialEq, Eq, Hash)]
544pub struct Channel {
545    pub id: ChannelId,
546    pub name: String,
547    pub visibility: ChannelVisibility,
548    /// parent_path is the channel ids from the root to this one (not including this one)
549    pub parent_path: Vec<ChannelId>,
550}
551
552impl Channel {
553    pub fn from_model(value: channel::Model) -> Self {
554        Channel {
555            id: value.id,
556            visibility: value.visibility,
557            name: value.clone().name,
558            parent_path: value.ancestors().collect(),
559        }
560    }
561
562    pub fn to_proto(&self) -> proto::Channel {
563        proto::Channel {
564            id: self.id.to_proto(),
565            name: self.name.clone(),
566            visibility: self.visibility.into(),
567            parent_path: self.parent_path.iter().map(|c| c.to_proto()).collect(),
568        }
569    }
570}
571
572#[derive(Debug, PartialEq, Eq, Hash)]
573pub struct ChannelMember {
574    pub role: ChannelRole,
575    pub user_id: UserId,
576    pub kind: proto::channel_member::Kind,
577}
578
579impl ChannelMember {
580    pub fn to_proto(&self) -> proto::ChannelMember {
581        proto::ChannelMember {
582            role: self.role.into(),
583            user_id: self.user_id.to_proto(),
584            kind: self.kind.into(),
585        }
586    }
587}
588
/// The channel-related state for one user.
#[derive(Debug, PartialEq)]
pub struct ChannelsForUser {
    /// Channels visible to the user.
    pub channels: Vec<Channel>,
    /// The user's raw channel-membership rows.
    pub channel_memberships: Vec<channel_member::Model>,
    /// User ids keyed by channel id (NOTE(review): presumably the users currently
    /// participating in each channel — confirm).
    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
    pub hosted_projects: Vec<proto::HostedProject>,

    // Buffer versions and message ids the user has observed, alongside the latest
    // ones (presumably used to compute unread state — TODO confirm).
    pub observed_buffer_versions: Vec<proto::ChannelBufferVersion>,
    pub observed_channel_messages: Vec<proto::ChannelMessageId>,
    pub latest_buffer_versions: Vec<proto::ChannelBufferVersion>,
    pub latest_channel_messages: Vec<proto::ChannelMessageId>,
}
601
/// A channel buffer rejoined on reconnect, plus the connection id it was rejoined from.
#[derive(Debug)]
pub struct RejoinedChannelBuffer {
    pub buffer: proto::RejoinedChannelBuffer,
    pub old_connection_id: ConnectionId,
}

/// Result of joining a room: the room and, if any, its associated channel row.
#[derive(Clone)]
pub struct JoinRoom {
    pub room: proto::Room,
    pub channel: Option<channel::Model>,
}

/// Result of rejoining a room.
pub struct RejoinedRoom {
    pub room: proto::Room,
    pub rejoined_projects: Vec<RejoinedProject>,
    pub reshared_projects: Vec<ResharedProject>,
    pub channel: Option<channel::Model>,
}

/// A project shared again on rejoin (NOTE(review): presumably by its host — confirm).
pub struct ResharedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<proto::WorktreeMetadata>,
}

/// A project rejoined on reconnect.
pub struct RejoinedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<RejoinedWorktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}

/// Worktree state for a rejoined project, expressed as updated and removed entries,
/// repositories, diagnostics and settings files.
#[derive(Debug)]
pub struct RejoinedWorktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub updated_entries: Vec<proto::Entry>,
    pub removed_entries: Vec<u64>,
    pub updated_repositories: Vec<proto::RepositoryEntry>,
    pub removed_repositories: Vec<u64>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}

/// Result of leaving a room.
pub struct LeftRoom {
    pub room: proto::Room,
    pub channel: Option<channel::Model>,
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
    /// Whether the room was deleted as part of leaving it.
    pub deleted: bool,
}

/// Result of refreshing a room.
pub struct RefreshedRoom {
    pub room: proto::Room,
    pub channel: Option<channel::Model>,
    pub stale_participant_user_ids: Vec<UserId>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}

/// Result of refreshing a channel buffer: its connections and collaborators.
pub struct RefreshedChannelBuffer {
    pub connection_ids: Vec<ConnectionId>,
    pub collaborators: Vec<proto::Collaborator>,
}

/// A project together with its collaborators, worktrees (keyed by worktree id) and
/// language servers.
pub struct Project {
    pub id: ProjectId,
    pub role: ChannelRole,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
679
680pub struct ProjectCollaborator {
681    pub connection_id: ConnectionId,
682    pub user_id: UserId,
683    pub replica_id: ReplicaId,
684    pub is_host: bool,
685}
686
687impl ProjectCollaborator {
688    pub fn to_proto(&self) -> proto::Collaborator {
689        proto::Collaborator {
690            peer_id: Some(self.connection_id.into()),
691            replica_id: self.replica_id.0 as u32,
692            user_id: self.user_id.to_proto(),
693        }
694    }
695}
696
/// A project that was left, with its host (if known) and remaining connection ids.
#[derive(Debug)]
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: Option<UserId>,
    pub host_connection_id: Option<ConnectionId>,
    pub connection_ids: Vec<ConnectionId>,
}

/// Full worktree state; repository entries are keyed by a `u64` id.
pub struct Worktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}

/// A settings file inside a worktree: its path and raw content.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    pub path: String,
    pub content: String,
}

/// Metadata for a new extension version being published.
pub struct NewExtensionVersion {
    pub name: String,
    pub version: semver::Version,
    pub description: String,
    pub authors: Vec<String>,
    pub repository: String,
    pub schema_version: i32,
    pub wasm_api_version: Option<String>,
    pub published_at: PrimitiveDateTime,
}

/// Inclusive ranges of schema versions and wasm API versions
/// (presumably the ones a client supports — TODO confirm at call sites).
pub struct ExtensionVersionConstraints {
    pub schema_versions: RangeInclusive<i32>,
    pub wasm_api_versions: RangeInclusive<SemanticVersion>,
}