db.rs

  1#[cfg(test)]
  2mod db_tests;
  3#[cfg(test)]
  4pub mod test_db;
  5
  6mod ids;
  7mod queries;
  8mod tables;
  9
 10use crate::{executor::Executor, Error, Result};
 11use anyhow::anyhow;
 12use collections::{BTreeMap, HashMap, HashSet};
 13use dashmap::DashMap;
 14use futures::StreamExt;
 15use rand::{prelude::StdRng, Rng, SeedableRng};
 16use rpc::{proto, ConnectionId};
 17use sea_orm::{
 18    entity::prelude::*, ActiveValue, Condition, ConnectionTrait, DatabaseConnection,
 19    DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType,
 20    QueryOrder, QuerySelect, Statement, TransactionTrait,
 21};
 22use sea_query::{Alias, Expr, OnConflict, Query};
 23use serde::{Deserialize, Serialize};
 24use sqlx::{
 25    migrate::{Migrate, Migration, MigrationSource},
 26    Connection,
 27};
 28use std::{
 29    fmt::Write as _,
 30    future::Future,
 31    marker::PhantomData,
 32    ops::{Deref, DerefMut},
 33    path::Path,
 34    rc::Rc,
 35    sync::Arc,
 36    time::Duration,
 37};
 38use tables::*;
 39use tokio::sync::{Mutex, OwnedMutexGuard};
 40
 41pub use ids::*;
 42pub use sea_orm::ConnectOptions;
 43pub use tables::user::Model as User;
 44
/// Handle to the collaboration database: a sea-orm connection pool plus the
/// in-process state used to serialize work on individual rooms and to retry
/// transactions that fail with serialization errors.
pub struct Database {
    /// Connection options; kept so `migrate` can open its own sqlx connection
    /// to the same URL.
    options: ConnectOptions,
    /// Pool from which every transaction is started.
    pool: DatabaseConnection,
    /// One async mutex per room, created on first use. Room transactions hold
    /// it while committing, and the returned `RoomGuard` keeps holding it.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// RNG used to jitter retry delays (seeded with 0 in `new`).
    rng: Mutex<StdRng>,
    /// Used to sleep between retries and, in tests, to simulate random delays.
    executor: Executor,
    /// Runtime that `run` blocks on in test builds. `new` leaves it `None`;
    /// presumably the test harness installs one — confirm in `test_db`.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
 54
 55impl Database {
 56    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
 57        Ok(Self {
 58            options: options.clone(),
 59            pool: sea_orm::Database::connect(options).await?,
 60            rooms: DashMap::with_capacity(16384),
 61            rng: Mutex::new(StdRng::seed_from_u64(0)),
 62            executor,
 63            #[cfg(test)]
 64            runtime: None,
 65        })
 66    }
 67
 68    #[cfg(test)]
 69    pub fn reset(&self) {
 70        self.rooms.clear();
 71    }
 72
 73    pub async fn migrate(
 74        &self,
 75        migrations_path: &Path,
 76        ignore_checksum_mismatch: bool,
 77    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 78        let migrations = MigrationSource::resolve(migrations_path)
 79            .await
 80            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 81
 82        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
 83
 84        connection.ensure_migrations_table().await?;
 85        let applied_migrations: HashMap<_, _> = connection
 86            .list_applied_migrations()
 87            .await?
 88            .into_iter()
 89            .map(|m| (m.version, m))
 90            .collect();
 91
 92        let mut new_migrations = Vec::new();
 93        for migration in migrations {
 94            match applied_migrations.get(&migration.version) {
 95                Some(applied_migration) => {
 96                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 97                    {
 98                        Err(anyhow!(
 99                            "checksum mismatch for applied migration {}",
100                            migration.description
101                        ))?;
102                    }
103                }
104                None => {
105                    let elapsed = connection.apply(&migration).await?;
106                    new_migrations.push((migration, elapsed));
107                }
108            }
109        }
110
111        Ok(new_migrations)
112    }
113
114    async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
115    where
116        F: Send + Fn(TransactionHandle) -> Fut,
117        Fut: Send + Future<Output = Result<T>>,
118    {
119        let body = async {
120            let mut i = 0;
121            loop {
122                let (tx, result) = self.with_transaction(&f).await?;
123                match result {
124                    Ok(result) => match tx.commit().await.map_err(Into::into) {
125                        Ok(()) => return Ok(result),
126                        Err(error) => {
127                            if !self.retry_on_serialization_error(&error, i).await {
128                                return Err(error);
129                            }
130                        }
131                    },
132                    Err(error) => {
133                        tx.rollback().await?;
134                        if !self.retry_on_serialization_error(&error, i).await {
135                            return Err(error);
136                        }
137                    }
138                }
139                i += 1;
140            }
141        };
142
143        self.run(body).await
144    }
145
146    async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
147    where
148        F: Send + Fn(TransactionHandle) -> Fut,
149        Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
150    {
151        let body = async {
152            let mut i = 0;
153            loop {
154                let (tx, result) = self.with_transaction(&f).await?;
155                match result {
156                    Ok(Some((room_id, data))) => {
157                        let lock = self.rooms.entry(room_id).or_default().clone();
158                        let _guard = lock.lock_owned().await;
159                        match tx.commit().await.map_err(Into::into) {
160                            Ok(()) => {
161                                return Ok(Some(RoomGuard {
162                                    data,
163                                    _guard,
164                                    _not_send: PhantomData,
165                                }));
166                            }
167                            Err(error) => {
168                                if !self.retry_on_serialization_error(&error, i).await {
169                                    return Err(error);
170                                }
171                            }
172                        }
173                    }
174                    Ok(None) => match tx.commit().await.map_err(Into::into) {
175                        Ok(()) => return Ok(None),
176                        Err(error) => {
177                            if !self.retry_on_serialization_error(&error, i).await {
178                                return Err(error);
179                            }
180                        }
181                    },
182                    Err(error) => {
183                        tx.rollback().await?;
184                        if !self.retry_on_serialization_error(&error, i).await {
185                            return Err(error);
186                        }
187                    }
188                }
189                i += 1;
190            }
191        };
192
193        self.run(body).await
194    }
195
196    async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
197    where
198        F: Send + Fn(TransactionHandle) -> Fut,
199        Fut: Send + Future<Output = Result<T>>,
200    {
201        let body = async {
202            let mut i = 0;
203            loop {
204                let lock = self.rooms.entry(room_id).or_default().clone();
205                let _guard = lock.lock_owned().await;
206                let (tx, result) = self.with_transaction(&f).await?;
207                match result {
208                    Ok(data) => match tx.commit().await.map_err(Into::into) {
209                        Ok(()) => {
210                            return Ok(RoomGuard {
211                                data,
212                                _guard,
213                                _not_send: PhantomData,
214                            });
215                        }
216                        Err(error) => {
217                            if !self.retry_on_serialization_error(&error, i).await {
218                                return Err(error);
219                            }
220                        }
221                    },
222                    Err(error) => {
223                        tx.rollback().await?;
224                        if !self.retry_on_serialization_error(&error, i).await {
225                            return Err(error);
226                        }
227                    }
228                }
229                i += 1;
230            }
231        };
232
233        self.run(body).await
234    }
235
236    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
237    where
238        F: Send + Fn(TransactionHandle) -> Fut,
239        Fut: Send + Future<Output = Result<T>>,
240    {
241        let tx = self
242            .pool
243            .begin_with_config(Some(IsolationLevel::Serializable), None)
244            .await?;
245
246        let mut tx = Arc::new(Some(tx));
247        let result = f(TransactionHandle(tx.clone())).await;
248        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
249            return Err(anyhow!("couldn't complete transaction because it's still in use"))?;
250        };
251
252        Ok((tx, result))
253    }
254
255    async fn run<F, T>(&self, future: F) -> Result<T>
256    where
257        F: Future<Output = Result<T>>,
258    {
259        #[cfg(test)]
260        {
261            if let Executor::Deterministic(executor) = &self.executor {
262                executor.simulate_random_delay().await;
263            }
264
265            self.runtime.as_ref().unwrap().block_on(future)
266        }
267
268        #[cfg(not(test))]
269        {
270            future.await
271        }
272    }
273
274    async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
275        // If the error is due to a failure to serialize concurrent transactions, then retry
276        // this transaction after a delay. With each subsequent retry, double the delay duration.
277        // Also vary the delay randomly in order to ensure different database connections retry
278        // at different times.
279        if is_serialization_error(error) {
280            let base_delay = 4_u64 << prev_attempt_count.min(16);
281            let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
282            log::info!(
283                "retrying transaction after serialization error. delay: {} ms.",
284                randomized_delay
285            );
286            self.executor
287                .sleep(Duration::from_millis(randomized_delay as u64))
288                .await;
289            true
290        } else {
291            false
292        }
293    }
294}
295
296fn is_serialization_error(error: &Error) -> bool {
297    const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
298    match error {
299        Error::Database(
300            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
301            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
302        ) if error
303            .as_database_error()
304            .and_then(|error| error.code())
305            .as_deref()
306            == Some(SERIALIZATION_FAILURE_CODE) =>
307        {
308            true
309        }
310        _ => false,
311    }
312}
313
314struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
315
316impl Deref for TransactionHandle {
317    type Target = DatabaseTransaction;
318
319    fn deref(&self) -> &Self::Target {
320        self.0.as_ref().as_ref().unwrap()
321    }
322}
323
324pub struct RoomGuard<T> {
325    data: T,
326    _guard: OwnedMutexGuard<()>,
327    _not_send: PhantomData<Rc<()>>,
328}
329
330impl<T> Deref for RoomGuard<T> {
331    type Target = T;
332
333    fn deref(&self) -> &T {
334        &self.data
335    }
336}
337
338impl<T> DerefMut for RoomGuard<T> {
339    fn deref_mut(&mut self) -> &mut T {
340        &mut self.data
341    }
342}
343
344impl<T> RoomGuard<T> {
345    pub fn into_inner(self) -> T {
346        self.data
347    }
348}
349
350#[derive(Clone, Debug, PartialEq, Eq)]
351pub enum Contact {
352    Accepted {
353        user_id: UserId,
354        should_notify: bool,
355        busy: bool,
356    },
357    Outgoing {
358        user_id: UserId,
359    },
360    Incoming {
361        user_id: UserId,
362        should_notify: bool,
363    },
364}
365
366impl Contact {
367    pub fn user_id(&self) -> UserId {
368        match self {
369            Contact::Accepted { user_id, .. } => *user_id,
370            Contact::Outgoing { user_id } => *user_id,
371            Contact::Incoming { user_id, .. } => *user_id,
372        }
373    }
374}
375
/// An email address together with its confirmation code.
///
/// Can be read directly from a query row (`FromQueryResult`) and
/// (de)serialized with serde.
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
381
/// Details of a new signup; derives `Deserialize` so it can be built from
/// serialized input.
///
/// NOTE(review): field meanings below are inferred from their names — confirm
/// against the handler that produces this payload.
#[derive(Clone, Debug, Deserialize)]
pub struct NewSignup {
    pub email_address: String,
    // Platforms the user selected (presumably any combination may be set).
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    /// Optional identifier of the device the signup came from.
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
    /// Optional explicit creation time; how `None` is handled is decided by the
    /// code that stores the signup.
    pub created_at: Option<DateTime>,
}
394
/// Aggregate waitlist counts, loadable from a query row (`FromQueryResult`).
///
/// NOTE(review): presumably `count` is the total and the remaining fields break
/// it down by platform, with `unknown_count` covering signups that reported no
/// platform — confirm against the query that fills this.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
pub struct WaitlistSummary {
    pub count: i64,
    pub linux_count: i64,
    pub mac_count: i64,
    pub windows_count: i64,
    pub unknown_count: i64,
}
403
/// Parameters for creating a user from a GitHub account.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    /// Presumably the number of invites granted to the new user — confirm.
    pub invite_count: i32,
}
410
/// Outcome of creating a new user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// Presumably the user whose invite led to this signup, if any — confirm.
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
418
/// A channel row: its id, name, and optional parent channel.
#[derive(FromQueryResult, Debug, PartialEq)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// `None` when the channel has no parent.
    pub parent_id: Option<ChannelId>,
}
425
/// Channel information gathered for a single user.
#[derive(Debug, PartialEq)]
pub struct ChannelsForUser {
    pub channels: Vec<Channel>,
    /// User ids participating in each channel.
    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
    /// Ids of the channels in which the user has admin privileges.
    pub channels_with_admin_privileges: HashSet<ChannelId>,
}
432
/// Result of joining a room: the room's protocol state plus its associated
/// channel (if any) and that channel's members.
#[derive(Clone)]
pub struct JoinRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}
439
/// Result of rejoining a room, including the projects that were rejoined or
/// re-shared along with it.
pub struct RejoinedRoom {
    pub room: proto::Room,
    pub rejoined_projects: Vec<RejoinedProject>,
    pub reshared_projects: Vec<ResharedProject>,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}
447
/// A project re-shared as part of rejoining a room.
///
/// NOTE(review): presumably re-shared by its host after reconnecting, with
/// `old_connection_id` being the host's previous connection — confirm.
pub struct ResharedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<proto::WorktreeMetadata>,
}
454
/// A project rejoined as part of rejoining a room, with the state needed to
/// bring the rejoining connection up to date.
pub struct RejoinedProject {
    pub id: ProjectId,
    /// Connection id the participant had before rejoining (presumably).
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<RejoinedWorktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
462
/// Worktree state sent to a rejoining participant.
///
/// NOTE(review): the `updated_*` / `removed_*` fields look like a delta relative
/// to what the participant last saw, with removals given by numeric id —
/// confirm against the query that builds this.
#[derive(Debug)]
pub struct RejoinedWorktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub updated_entries: Vec<proto::Entry>,
    pub removed_entries: Vec<u64>,
    pub updated_repositories: Vec<proto::RepositoryEntry>,
    pub removed_repositories: Vec<u64>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
478
/// Result of a participant leaving a room.
pub struct LeftRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    /// Projects affected by the departure, keyed by project id.
    pub left_projects: HashMap<ProjectId, LeftProject>,
    /// Users whose pending calls were canceled as a result.
    pub canceled_calls_to_user_ids: Vec<UserId>,
    /// Presumably set when the room itself was deleted — confirm.
    pub deleted: bool,
}
487
/// Result of refreshing a room, including participants detected as stale and
/// the calls canceled because of them (presumably — confirm in queries).
pub struct RefreshedRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub stale_participant_user_ids: Vec<UserId>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
495
/// A shared project: its collaborators, worktrees (ordered by id), and
/// language servers.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
501
502pub struct ProjectCollaborator {
503    pub connection_id: ConnectionId,
504    pub user_id: UserId,
505    pub replica_id: ReplicaId,
506    pub is_host: bool,
507}
508
509impl ProjectCollaborator {
510    pub fn to_proto(&self) -> proto::Collaborator {
511        proto::Collaborator {
512            peer_id: Some(self.connection_id.into()),
513            replica_id: self.replica_id.0 as u32,
514            user_id: self.user_id.to_proto(),
515        }
516    }
517}
518
/// A project affected by a participant leaving, with the host's identity and
/// the connections involved.
///
/// NOTE(review): presumably `connection_ids` are the remaining connections to
/// notify — confirm at the call site.
#[derive(Debug)]
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    pub connection_ids: Vec<ConnectionId>,
}
526
/// Full state of one worktree in a shared project.
pub struct Worktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    /// Repository entries keyed (and ordered) by a numeric id.
    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
539
/// A settings file found in a worktree: its path and raw text content.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    pub path: String,
    pub content: String,
}