db.rs

  1mod ids;
  2mod queries;
  3mod tables;
  4#[cfg(test)]
  5pub mod tests;
  6
  7use crate::{executor::Executor, Error, Result};
  8use anyhow::anyhow;
  9use collections::{BTreeMap, HashMap, HashSet};
 10use dashmap::DashMap;
 11use futures::StreamExt;
 12use rand::{prelude::StdRng, Rng, SeedableRng};
 13use rpc::{
 14    proto::{self},
 15    ConnectionId, ExtensionMetadata,
 16};
 17use sea_orm::{
 18    entity::prelude::*,
 19    sea_query::{Alias, Expr, OnConflict},
 20    ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
 21    FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
 22    TransactionTrait,
 23};
 24use semantic_version::SemanticVersion;
 25use serde::{Deserialize, Serialize};
 26use sqlx::{
 27    migrate::{Migrate, Migration, MigrationSource},
 28    Connection,
 29};
 30use std::ops::RangeInclusive;
 31use std::{
 32    fmt::Write as _,
 33    future::Future,
 34    marker::PhantomData,
 35    ops::{Deref, DerefMut},
 36    path::Path,
 37    rc::Rc,
 38    sync::Arc,
 39    time::Duration,
 40};
 41use time::PrimitiveDateTime;
 42use tokio::sync::{Mutex, OwnedMutexGuard};
 43
 44#[cfg(test)]
 45pub use tests::TestDb;
 46
 47pub use ids::*;
 48pub use queries::contributors::ContributorSelector;
 49pub use sea_orm::ConnectOptions;
 50pub use tables::user::Model as User;
 51pub use tables::*;
 52
/// Database gives you a handle that lets you access the database.
/// It handles pooling internally.
pub struct Database {
    /// Options this handle was created with; `migrate` reads the URL from
    /// them to open its own connection.
    options: ConnectOptions,
    /// Connection on which the transaction helpers begin their transactions.
    pool: DatabaseConnection,
    /// One mutex per room, created on first use by the room transaction
    /// helpers; the owned guard ends up inside a `RoomGuard`.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// Random source used to vary the delay before a transaction retry.
    rng: Mutex<StdRng>,
    /// Used to sleep between transaction retries.
    executor: Executor,
    /// Starts empty in `new`. NOTE(review): presumably populated elsewhere — confirm.
    notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
    /// Starts empty in `new`. NOTE(review): presumably populated elsewhere — confirm.
    notification_kinds_by_name: HashMap<String, NotificationKindId>,
    /// Test-only runtime that `run` blocks on; `new` leaves it as `None`.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
 66
 67// The `Database` type has so many methods that its impl blocks are split into
 68// separate files in the `queries` folder.
 69impl Database {
 70    /// Connects to the database with the given options
 71    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
 72        sqlx::any::install_default_drivers();
 73        Ok(Self {
 74            options: options.clone(),
 75            pool: sea_orm::Database::connect(options).await?,
 76            rooms: DashMap::with_capacity(16384),
 77            rng: Mutex::new(StdRng::seed_from_u64(0)),
 78            notification_kinds_by_id: HashMap::default(),
 79            notification_kinds_by_name: HashMap::default(),
 80            executor,
 81            #[cfg(test)]
 82            runtime: None,
 83        })
 84    }
 85
    /// Removes every entry from the per-room lock map (test builds only).
    #[cfg(test)]
    pub fn reset(&self) {
        self.rooms.clear();
    }
 90
 91    /// Runs the database migrations.
 92    pub async fn migrate(
 93        &self,
 94        migrations_path: &Path,
 95        ignore_checksum_mismatch: bool,
 96    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 97        let migrations = MigrationSource::resolve(migrations_path)
 98            .await
 99            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
100
101        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
102
103        connection.ensure_migrations_table().await?;
104        let applied_migrations: HashMap<_, _> = connection
105            .list_applied_migrations()
106            .await?
107            .into_iter()
108            .map(|m| (m.version, m))
109            .collect();
110
111        let mut new_migrations = Vec::new();
112        for migration in migrations {
113            match applied_migrations.get(&migration.version) {
114                Some(applied_migration) => {
115                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
116                    {
117                        Err(anyhow!(
118                            "checksum mismatch for applied migration {}",
119                            migration.description
120                        ))?;
121                    }
122                }
123                None => {
124                    let elapsed = connection.apply(&migration).await?;
125                    new_migrations.push((migration, elapsed));
126                }
127            }
128        }
129
130        Ok(new_migrations)
131    }
132
133    /// Transaction runs things in a transaction. If you want to call other methods
134    /// and pass the transaction around you need to reborrow the transaction at each
135    /// call site with: `&*tx`.
136    pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
137    where
138        F: Send + Fn(TransactionHandle) -> Fut,
139        Fut: Send + Future<Output = Result<T>>,
140    {
141        let body = async {
142            let mut i = 0;
143            loop {
144                let (tx, result) = self.with_transaction(&f).await?;
145                match result {
146                    Ok(result) => match tx.commit().await.map_err(Into::into) {
147                        Ok(()) => return Ok(result),
148                        Err(error) => {
149                            if !self.retry_on_serialization_error(&error, i).await {
150                                return Err(error);
151                            }
152                        }
153                    },
154                    Err(error) => {
155                        tx.rollback().await?;
156                        if !self.retry_on_serialization_error(&error, i).await {
157                            return Err(error);
158                        }
159                    }
160                }
161                i += 1;
162            }
163        };
164
165        self.run(body).await
166    }
167
168    pub async fn weak_transaction<F, Fut, T>(&self, f: F) -> Result<T>
169    where
170        F: Send + Fn(TransactionHandle) -> Fut,
171        Fut: Send + Future<Output = Result<T>>,
172    {
173        let body = async {
174            let (tx, result) = self.with_weak_transaction(&f).await?;
175            match result {
176                Ok(result) => match tx.commit().await.map_err(Into::into) {
177                    Ok(()) => return Ok(result),
178                    Err(error) => {
179                        return Err(error);
180                    }
181                },
182                Err(error) => {
183                    tx.rollback().await?;
184                    return Err(error);
185                }
186            }
187        };
188
189        self.run(body).await
190    }
191
    /// The same as room_transaction, but if you need to only optionally return a Room.
    ///
    /// `f` returns `Ok(Some((room_id, data)))` to receive a [`RoomGuard`] for
    /// that room, or `Ok(None)` to commit without taking any room lock.
    /// Serialization failures are retried for as long as
    /// `retry_on_serialization_error` allows; any other error is returned.
    async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
    where
        F: Send + Fn(TransactionHandle) -> Fut,
        Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
    {
        let body = async {
            // Number of attempts made so far; selects the retry delay.
            let mut i = 0;
            loop {
                let (tx, result) = self.with_transaction(&f).await?;
                match result {
                    Ok(Some((room_id, data))) => {
                        // The room is only known once `f` has run, so its lock is
                        // taken here: after the transaction body, before the commit.
                        let lock = self.rooms.entry(room_id).or_default().clone();
                        let _guard = lock.lock_owned().await;
                        match tx.commit().await.map_err(Into::into) {
                            Ok(()) => {
                                // Hand the still-held lock to the caller.
                                return Ok(Some(RoomGuard {
                                    data,
                                    _guard,
                                    _not_send: PhantomData,
                                }));
                            }
                            Err(error) => {
                                // `_guard` is still in scope, so the room lock is
                                // held while the retry delay elapses.
                                if !self.retry_on_serialization_error(&error, i).await {
                                    return Err(error);
                                }
                            }
                        }
                    }
                    // No room involved: commit without locking anything.
                    Ok(None) => match tx.commit().await.map_err(Into::into) {
                        Ok(()) => return Ok(None),
                        Err(error) => {
                            if !self.retry_on_serialization_error(&error, i).await {
                                return Err(error);
                            }
                        }
                    },
                    Err(error) => {
                        // A failed rollback is returned instead of `error`.
                        tx.rollback().await?;
                        if !self.retry_on_serialization_error(&error, i).await {
                            return Err(error);
                        }
                    }
                }
                i += 1;
            }
        };

        self.run(body).await
    }
242
243    /// room_transaction runs the block in a transaction. It returns a RoomGuard, that keeps
244    /// the database locked until it is dropped. This ensures that updates sent to clients are
245    /// properly serialized with respect to database changes.
246    async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
247    where
248        F: Send + Fn(TransactionHandle) -> Fut,
249        Fut: Send + Future<Output = Result<T>>,
250    {
251        let body = async {
252            let mut i = 0;
253            loop {
254                let lock = self.rooms.entry(room_id).or_default().clone();
255                let _guard = lock.lock_owned().await;
256                let (tx, result) = self.with_transaction(&f).await?;
257                match result {
258                    Ok(data) => match tx.commit().await.map_err(Into::into) {
259                        Ok(()) => {
260                            return Ok(RoomGuard {
261                                data,
262                                _guard,
263                                _not_send: PhantomData,
264                            });
265                        }
266                        Err(error) => {
267                            if !self.retry_on_serialization_error(&error, i).await {
268                                return Err(error);
269                            }
270                        }
271                    },
272                    Err(error) => {
273                        tx.rollback().await?;
274                        if !self.retry_on_serialization_error(&error, i).await {
275                            return Err(error);
276                        }
277                    }
278                }
279                i += 1;
280            }
281        };
282
283        self.run(body).await
284    }
285
286    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
287    where
288        F: Send + Fn(TransactionHandle) -> Fut,
289        Fut: Send + Future<Output = Result<T>>,
290    {
291        let tx = self
292            .pool
293            .begin_with_config(Some(IsolationLevel::Serializable), None)
294            .await?;
295
296        let mut tx = Arc::new(Some(tx));
297        let result = f(TransactionHandle(tx.clone())).await;
298        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
299            return Err(anyhow!(
300                "couldn't complete transaction because it's still in use"
301            ))?;
302        };
303
304        Ok((tx, result))
305    }
306
307    async fn with_weak_transaction<F, Fut, T>(
308        &self,
309        f: &F,
310    ) -> Result<(DatabaseTransaction, Result<T>)>
311    where
312        F: Send + Fn(TransactionHandle) -> Fut,
313        Fut: Send + Future<Output = Result<T>>,
314    {
315        let tx = self
316            .pool
317            .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
318            .await?;
319
320        let mut tx = Arc::new(Some(tx));
321        let result = f(TransactionHandle(tx.clone())).await;
322        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
323            return Err(anyhow!(
324                "couldn't complete transaction because it's still in use"
325            ))?;
326        };
327
328        Ok((tx, result))
329    }
330
    /// Drives `future` to completion and returns its output.
    ///
    /// Outside of test builds this simply awaits `future`. In test builds it
    /// first awaits `simulate_random_delay` when the executor is
    /// `Executor::Deterministic`, then blocks on `future` using the runtime
    /// stored on `self`.
    ///
    /// # Panics
    ///
    /// In test builds, panics if `self.runtime` is `None`.
    async fn run<F, T>(&self, future: F) -> Result<T>
    where
        F: Future<Output = Result<T>>,
    {
        #[cfg(test)]
        {
            if let Executor::Deterministic(executor) = &self.executor {
                executor.simulate_random_delay().await;
            }

            self.runtime.as_ref().unwrap().block_on(future)
        }

        #[cfg(not(test))]
        {
            future.await
        }
    }
349
350    async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: usize) -> bool {
351        // If the error is due to a failure to serialize concurrent transactions, then retry
352        // this transaction after a delay. With each subsequent retry, double the delay duration.
353        // Also vary the delay randomly in order to ensure different database connections retry
354        // at different times.
355        const SLEEPS: [f32; 10] = [10., 20., 40., 80., 160., 320., 640., 1280., 2560., 5120.];
356        if is_serialization_error(error) && prev_attempt_count < SLEEPS.len() {
357            let base_delay = SLEEPS[prev_attempt_count];
358            let randomized_delay = base_delay * self.rng.lock().await.gen_range(0.5..=2.0);
359            log::info!(
360                "retrying transaction after serialization error. delay: {} ms.",
361                randomized_delay
362            );
363            self.executor
364                .sleep(Duration::from_millis(randomized_delay as u64))
365                .await;
366            true
367        } else {
368            false
369        }
370    }
371}
372
373fn is_serialization_error(error: &Error) -> bool {
374    const SERIALIZATION_FAILURE_CODE: &str = "40001";
375    match error {
376        Error::Database(
377            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
378            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
379        ) if error
380            .as_database_error()
381            .and_then(|error| error.code())
382            .as_deref()
383            == Some(SERIALIZATION_FAILURE_CODE) =>
384        {
385            true
386        }
387        _ => false,
388    }
389}
390
/// A handle to a [`DatabaseTransaction`].
///
/// The transaction lives in a shared `Arc<Option<_>>` slot. The code that
/// created the handle takes the transaction back out of the slot once the
/// closure it was passed to has finished, which only works if no clone of
/// the `Arc` is still alive at that point.
pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
393
394impl Deref for TransactionHandle {
395    type Target = DatabaseTransaction;
396
397    fn deref(&self) -> &Self::Target {
398        self.0.as_ref().as_ref().unwrap()
399    }
400}
401
/// [`RoomGuard`] wraps a value together with the held lock of a room's mutex.
/// The lock is released when the guard is dropped, so that updates to rooms
/// are serialized.
pub struct RoomGuard<T> {
    /// The value produced by the transaction closure.
    data: T,
    /// Owned guard of the room mutex; held for as long as this value lives.
    _guard: OwnedMutexGuard<()>,
    /// `Rc` is not `Send`, so this marker makes `RoomGuard` not `Send` either.
    _not_send: PhantomData<Rc<()>>,
}
409
410impl<T> Deref for RoomGuard<T> {
411    type Target = T;
412
413    fn deref(&self) -> &T {
414        &self.data
415    }
416}
417
418impl<T> DerefMut for RoomGuard<T> {
419    fn deref_mut(&mut self) -> &mut T {
420        &mut self.data
421    }
422}
423
424impl<T> RoomGuard<T> {
425    /// Returns the inner value of the guard.
426    pub fn into_inner(self) -> T {
427        self.data
428    }
429}
430
431#[derive(Clone, Debug, PartialEq, Eq)]
432pub enum Contact {
433    Accepted { user_id: UserId, busy: bool },
434    Outgoing { user_id: UserId },
435    Incoming { user_id: UserId },
436}
437
438impl Contact {
439    pub fn user_id(&self) -> UserId {
440        match self {
441            Contact::Accepted { user_id, .. } => *user_id,
442            Contact::Outgoing { user_id } => *user_id,
443            Contact::Incoming { user_id, .. } => *user_id,
444        }
445    }
446}
447
/// A list of notifications, each paired with a [`UserId`].
/// NOTE(review): presumably the user is the recipient — confirm with the code that builds these.
pub type NotificationBatch = Vec<(UserId, proto::Notification)>;
449
/// Data describing a newly created channel message.
/// NOTE(review): summary inferred from the type and field names — confirm against the query that builds it.
pub struct CreatedChannelMessage {
    pub message_id: MessageId,
    pub participant_connection_ids: Vec<ConnectionId>,
    pub channel_members: Vec<UserId>,
    pub notifications: NotificationBatch,
}
456
/// Data describing an updated channel message.
/// NOTE(review): summary inferred from the type and field names — confirm against the query that builds it.
pub struct UpdatedChannelMessage {
    pub message_id: MessageId,
    pub participant_connection_ids: Vec<ConnectionId>,
    pub notifications: NotificationBatch,
    /// `None` presumably means the message is not a reply — TODO confirm.
    pub reply_to_message_id: Option<MessageId>,
    pub timestamp: PrimitiveDateTime,
    pub deleted_mention_notification_ids: Vec<NotificationId>,
    pub updated_mention_notifications: Vec<rpc::proto::Notification>,
}
466
/// An email address paired with an email confirmation code.
/// Can be read from a query result and (de)serialized with serde.
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
472
/// Deserializable description of a new signup.
/// NOTE(review): field meanings are inferred from their names only — confirm with the signup handler.
#[derive(Clone, Debug, Deserialize)]
pub struct NewSignup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
    pub created_at: Option<DateTime>,
}
485
/// Counts that can be read directly from a query result.
/// NOTE(review): presumably a total plus per-platform waitlist counts — confirm with the producing query.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
pub struct WaitlistSummary {
    pub count: i64,
    pub linux_count: i64,
    pub mac_count: i64,
    pub windows_count: i64,
    pub unknown_count: i64,
}
494
/// The parameters to create a new user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    /// NOTE(review): presumably the GitHub username — confirm.
    pub github_login: String,
    /// NOTE(review): presumably GitHub's numeric user id — confirm.
    pub github_user_id: i32,
}
501
/// The result of creating a new user.
/// NOTE(review): the two optional fields are presumably absent when the user was not invited or has no recorded signup device — confirm.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
510
/// The result of updating a channel membership.
/// NOTE(review): `new_channels` / `removed_channels` presumably describe what changed for the affected user — confirm.
#[derive(Debug)]
pub struct MembershipUpdated {
    pub channel_id: ChannelId,
    pub new_channels: ChannelsForUser,
    pub removed_channels: Vec<ChannelId>,
}
518
/// The result of setting a member's role.
///
/// Carries either a [`Channel`] or a [`MembershipUpdated`].
/// NOTE(review): presumably `InviteUpdated` is used when the target only has a pending invite — confirm.
#[derive(Debug)]
pub enum SetMemberRoleResult {
    InviteUpdated(Channel),
    MembershipUpdated(MembershipUpdated),
}
525
/// The result of inviting a member to a channel: the channel plus a batch of
/// notifications.
#[derive(Debug)]
pub struct InviteMemberResult {
    pub channel: Channel,
    pub notifications: NotificationBatch,
}
532
/// Data produced when a channel invite is responded to.
/// NOTE(review): `membership_update` is presumably `None` when the invite was declined — confirm.
#[derive(Debug)]
pub struct RespondToChannelInvite {
    pub membership_update: Option<MembershipUpdated>,
    pub notifications: NotificationBatch,
}
538
/// Data produced when a member is removed from a channel.
/// NOTE(review): `notification_id` presumably identifies a notification to retract, if any — confirm.
#[derive(Debug)]
pub struct RemoveChannelMemberResult {
    pub membership_update: MembershipUpdated,
    pub notification_id: Option<NotificationId>,
}
544
545#[derive(Debug, PartialEq, Eq, Hash)]
546pub struct Channel {
547    pub id: ChannelId,
548    pub name: String,
549    pub visibility: ChannelVisibility,
550    /// parent_path is the channel ids from the root to this one (not including this one)
551    pub parent_path: Vec<ChannelId>,
552}
553
554impl Channel {
555    pub fn from_model(value: channel::Model) -> Self {
556        Channel {
557            id: value.id,
558            visibility: value.visibility,
559            name: value.clone().name,
560            parent_path: value.ancestors().collect(),
561        }
562    }
563
564    pub fn to_proto(&self) -> proto::Channel {
565        proto::Channel {
566            id: self.id.to_proto(),
567            name: self.name.clone(),
568            visibility: self.visibility.into(),
569            parent_path: self.parent_path.iter().map(|c| c.to_proto()).collect(),
570        }
571    }
572}
573
574#[derive(Debug, PartialEq, Eq, Hash)]
575pub struct ChannelMember {
576    pub role: ChannelRole,
577    pub user_id: UserId,
578    pub kind: proto::channel_member::Kind,
579}
580
581impl ChannelMember {
582    pub fn to_proto(&self) -> proto::ChannelMember {
583        proto::ChannelMember {
584            role: self.role.into(),
585            user_id: self.user_id.to_proto(),
586            kind: self.kind.into(),
587        }
588    }
589}
590
/// Channel-related data collected for one user.
/// NOTE(review): summary and the observed/latest distinction are inferred from names — confirm with the query that builds this.
#[derive(Debug, PartialEq)]
pub struct ChannelsForUser {
    pub channels: Vec<Channel>,
    pub channel_memberships: Vec<channel_member::Model>,
    /// User ids grouped by channel id.
    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
    pub hosted_projects: Vec<proto::HostedProject>,

    pub observed_buffer_versions: Vec<proto::ChannelBufferVersion>,
    pub observed_channel_messages: Vec<proto::ChannelMessageId>,
    pub latest_buffer_versions: Vec<proto::ChannelBufferVersion>,
    pub latest_channel_messages: Vec<proto::ChannelMessageId>,
}
603
/// A rejoined channel buffer together with a connection id.
/// NOTE(review): `old_connection_id` is presumably the connection the user held before rejoining — confirm.
#[derive(Debug)]
pub struct RejoinedChannelBuffer {
    pub buffer: proto::RejoinedChannelBuffer,
    pub old_connection_id: ConnectionId,
}
609
/// A room plus an optional channel row.
/// NOTE(review): `channel` is presumably `None` for rooms not attached to a channel — confirm.
#[derive(Clone)]
pub struct JoinRoom {
    pub room: proto::Room,
    pub channel: Option<channel::Model>,
}
615
/// A room, an optional channel row, and the projects that were rejoined or
/// reshared.
/// NOTE(review): summary inferred from field names — confirm with the rejoin query.
pub struct RejoinedRoom {
    pub room: proto::Room,
    pub rejoined_projects: Vec<RejoinedProject>,
    pub reshared_projects: Vec<ResharedProject>,
    pub channel: Option<channel::Model>,
}
622
/// Per-project data for a reshared project.
/// NOTE(review): `old_connection_id` is presumably the host's previous connection — confirm.
pub struct ResharedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<proto::WorktreeMetadata>,
}
629
/// Per-project data for a rejoined project.
/// NOTE(review): `old_connection_id` is presumably the rejoining user's previous connection — confirm.
pub struct RejoinedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<RejoinedWorktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
637
/// Worktree data for a rejoined project, expressed as updated and removed
/// items.
/// NOTE(review): presumably a delta relative to what the rejoining client already has — confirm.
#[derive(Debug)]
pub struct RejoinedWorktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub updated_entries: Vec<proto::Entry>,
    pub removed_entries: Vec<u64>,
    pub updated_repositories: Vec<proto::RepositoryEntry>,
    pub removed_repositories: Vec<u64>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
653
/// Data describing a room after a participant left it.
/// NOTE(review): summary inferred from names; `deleted` presumably reports whether the room itself was removed — confirm.
pub struct LeftRoom {
    pub room: proto::Room,
    pub channel: Option<channel::Model>,
    /// Left projects keyed by project id.
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
    pub deleted: bool,
}
661
/// Data describing a room after it was refreshed.
/// NOTE(review): summary inferred from names — confirm with the refresh query.
pub struct RefreshedRoom {
    pub room: proto::Room,
    pub channel: Option<channel::Model>,
    pub stale_participant_user_ids: Vec<UserId>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
668
/// Connection ids and collaborators of a channel buffer.
/// NOTE(review): presumably the state remaining after stale collaborators were cleared — confirm.
pub struct RefreshedChannelBuffer {
    pub connection_ids: Vec<ConnectionId>,
    pub collaborators: Vec<proto::Collaborator>,
}
673
/// A project with its collaborators, worktrees and language servers.
/// NOTE(review): `role` is presumably the role of the user this value was loaded for — confirm.
pub struct Project {
    pub id: ProjectId,
    pub role: ChannelRole,
    pub collaborators: Vec<ProjectCollaborator>,
    /// Worktrees in key order. NOTE(review): the key is presumably the worktree id — confirm.
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
681
682pub struct ProjectCollaborator {
683    pub connection_id: ConnectionId,
684    pub user_id: UserId,
685    pub replica_id: ReplicaId,
686    pub is_host: bool,
687}
688
689impl ProjectCollaborator {
690    pub fn to_proto(&self) -> proto::Collaborator {
691        proto::Collaborator {
692            peer_id: Some(self.connection_id.into()),
693            replica_id: self.replica_id.0 as u32,
694            user_id: self.user_id.to_proto(),
695        }
696    }
697}
698
/// Data describing a project that a participant left.
/// NOTE(review): `connection_ids` are presumably the connections to notify — confirm.
#[derive(Debug)]
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: Option<UserId>,
    pub host_connection_id: Option<ConnectionId>,
    pub connection_ids: Vec<ConnectionId>,
}
706
/// A worktree of a project with its entries, repositories, diagnostics and
/// settings files.
/// NOTE(review): the meaning of `scan_id` versus `completed_scan_id` is not visible here — confirm with the worktree code.
pub struct Worktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    /// Repository entries in key order. NOTE(review): the key's meaning is not visible here — confirm.
    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
719
/// A settings file belonging to a worktree, as a path and its content.
/// NOTE(review): `path` is presumably relative to the worktree root — confirm.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    pub path: String,
    pub content: String,
}
725
/// Metadata for a new version of an extension.
/// NOTE(review): summary inferred from the type and field names — confirm with the code that inserts it.
pub struct NewExtensionVersion {
    pub name: String,
    pub version: semver::Version,
    pub description: String,
    pub authors: Vec<String>,
    pub repository: String,
    pub schema_version: i32,
    pub wasm_api_version: Option<String>,
    pub published_at: PrimitiveDateTime,
}
736
/// Inclusive ranges of schema versions and wasm API versions.
/// NOTE(review): presumably used to filter extension versions to those a client supports — confirm.
pub struct ExtensionVersionConstraints {
    pub schema_versions: RangeInclusive<i32>,
    pub wasm_api_versions: RangeInclusive<SemanticVersion>,
}