db.rs

  1#[cfg(test)]
  2mod db_tests;
  3#[cfg(test)]
  4pub mod test_db;
  5
  6mod ids;
  7mod queries;
  8mod tables;
  9
 10use crate::{executor::Executor, Error, Result};
 11use anyhow::anyhow;
 12use collections::{BTreeMap, HashMap, HashSet};
 13use dashmap::DashMap;
 14use futures::StreamExt;
 15use rand::{prelude::StdRng, Rng, SeedableRng};
 16use rpc::{proto, ConnectionId};
 17use sea_orm::{
 18    entity::prelude::*, ActiveValue, Condition, ConnectionTrait, DatabaseConnection,
 19    DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType,
 20    QueryOrder, QuerySelect, Statement, TransactionTrait,
 21};
 22use sea_query::{Alias, Expr, OnConflict, Query};
 23use serde::{Deserialize, Serialize};
 24use sqlx::{
 25    migrate::{Migrate, Migration, MigrationSource},
 26    Connection,
 27};
 28use std::{
 29    fmt::Write as _,
 30    future::Future,
 31    marker::PhantomData,
 32    ops::{Deref, DerefMut},
 33    path::Path,
 34    rc::Rc,
 35    sync::Arc,
 36    time::Duration,
 37};
 38use tables::*;
 39use tokio::sync::{Mutex, OwnedMutexGuard};
 40
 41pub use ids::*;
 42pub use sea_orm::ConnectOptions;
 43pub use tables::user::Model as User;
 44
/// Handle to the collaboration database.
///
/// Wraps a sea-orm connection pool together with the in-process state used by
/// the transaction helpers in `impl Database`: per-room locks, the RNG used to
/// jitter retry delays, and the executor used to sleep between retries.
pub struct Database {
    /// Connection options; kept so `migrate` can open its own sqlx connection to
    /// the same URL.
    options: ConnectOptions,
    /// sea-orm connection pool from which every transaction is begun.
    pool: DatabaseConnection,
    /// One async mutex per room, created on first use; room transactions hold
    /// the room's lock while committing.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// RNG used to randomize serialization-retry delays.
    rng: Mutex<StdRng>,
    /// Executor used to sleep between retries (and, with a deterministic
    /// executor in tests, to inject simulated delays).
    executor: Executor,
    /// Test-only runtime on which `run` drives transaction futures via `block_on`.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
 54
 55// The `Database` type has so many methods that its impl blocks are split into
 56// separate files in the `queries` folder.
 57impl Database {
 58    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
 59        Ok(Self {
 60            options: options.clone(),
 61            pool: sea_orm::Database::connect(options).await?,
 62            rooms: DashMap::with_capacity(16384),
 63            rng: Mutex::new(StdRng::seed_from_u64(0)),
 64            executor,
 65            #[cfg(test)]
 66            runtime: None,
 67        })
 68    }
 69
 70    #[cfg(test)]
 71    pub fn reset(&self) {
 72        self.rooms.clear();
 73    }
 74
 75    pub async fn migrate(
 76        &self,
 77        migrations_path: &Path,
 78        ignore_checksum_mismatch: bool,
 79    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 80        let migrations = MigrationSource::resolve(migrations_path)
 81            .await
 82            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 83
 84        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
 85
 86        connection.ensure_migrations_table().await?;
 87        let applied_migrations: HashMap<_, _> = connection
 88            .list_applied_migrations()
 89            .await?
 90            .into_iter()
 91            .map(|m| (m.version, m))
 92            .collect();
 93
 94        let mut new_migrations = Vec::new();
 95        for migration in migrations {
 96            match applied_migrations.get(&migration.version) {
 97                Some(applied_migration) => {
 98                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 99                    {
100                        Err(anyhow!(
101                            "checksum mismatch for applied migration {}",
102                            migration.description
103                        ))?;
104                    }
105                }
106                None => {
107                    let elapsed = connection.apply(&migration).await?;
108                    new_migrations.push((migration, elapsed));
109                }
110            }
111        }
112
113        Ok(new_migrations)
114    }
115
116    async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
117    where
118        F: Send + Fn(TransactionHandle) -> Fut,
119        Fut: Send + Future<Output = Result<T>>,
120    {
121        let body = async {
122            let mut i = 0;
123            loop {
124                let (tx, result) = self.with_transaction(&f).await?;
125                match result {
126                    Ok(result) => match tx.commit().await.map_err(Into::into) {
127                        Ok(()) => return Ok(result),
128                        Err(error) => {
129                            if !self.retry_on_serialization_error(&error, i).await {
130                                return Err(error);
131                            }
132                        }
133                    },
134                    Err(error) => {
135                        tx.rollback().await?;
136                        if !self.retry_on_serialization_error(&error, i).await {
137                            return Err(error);
138                        }
139                    }
140                }
141                i += 1;
142            }
143        };
144
145        self.run(body).await
146    }
147
148    async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
149    where
150        F: Send + Fn(TransactionHandle) -> Fut,
151        Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
152    {
153        let body = async {
154            let mut i = 0;
155            loop {
156                let (tx, result) = self.with_transaction(&f).await?;
157                match result {
158                    Ok(Some((room_id, data))) => {
159                        let lock = self.rooms.entry(room_id).or_default().clone();
160                        let _guard = lock.lock_owned().await;
161                        match tx.commit().await.map_err(Into::into) {
162                            Ok(()) => {
163                                return Ok(Some(RoomGuard {
164                                    data,
165                                    _guard,
166                                    _not_send: PhantomData,
167                                }));
168                            }
169                            Err(error) => {
170                                if !self.retry_on_serialization_error(&error, i).await {
171                                    return Err(error);
172                                }
173                            }
174                        }
175                    }
176                    Ok(None) => match tx.commit().await.map_err(Into::into) {
177                        Ok(()) => return Ok(None),
178                        Err(error) => {
179                            if !self.retry_on_serialization_error(&error, i).await {
180                                return Err(error);
181                            }
182                        }
183                    },
184                    Err(error) => {
185                        tx.rollback().await?;
186                        if !self.retry_on_serialization_error(&error, i).await {
187                            return Err(error);
188                        }
189                    }
190                }
191                i += 1;
192            }
193        };
194
195        self.run(body).await
196    }
197
198    async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
199    where
200        F: Send + Fn(TransactionHandle) -> Fut,
201        Fut: Send + Future<Output = Result<T>>,
202    {
203        let body = async {
204            let mut i = 0;
205            loop {
206                let lock = self.rooms.entry(room_id).or_default().clone();
207                let _guard = lock.lock_owned().await;
208                let (tx, result) = self.with_transaction(&f).await?;
209                match result {
210                    Ok(data) => match tx.commit().await.map_err(Into::into) {
211                        Ok(()) => {
212                            return Ok(RoomGuard {
213                                data,
214                                _guard,
215                                _not_send: PhantomData,
216                            });
217                        }
218                        Err(error) => {
219                            if !self.retry_on_serialization_error(&error, i).await {
220                                return Err(error);
221                            }
222                        }
223                    },
224                    Err(error) => {
225                        tx.rollback().await?;
226                        if !self.retry_on_serialization_error(&error, i).await {
227                            return Err(error);
228                        }
229                    }
230                }
231                i += 1;
232            }
233        };
234
235        self.run(body).await
236    }
237
238    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
239    where
240        F: Send + Fn(TransactionHandle) -> Fut,
241        Fut: Send + Future<Output = Result<T>>,
242    {
243        let tx = self
244            .pool
245            .begin_with_config(Some(IsolationLevel::Serializable), None)
246            .await?;
247
248        let mut tx = Arc::new(Some(tx));
249        let result = f(TransactionHandle(tx.clone())).await;
250        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
251            return Err(anyhow!("couldn't complete transaction because it's still in use"))?;
252        };
253
254        Ok((tx, result))
255    }
256
257    async fn run<F, T>(&self, future: F) -> Result<T>
258    where
259        F: Future<Output = Result<T>>,
260    {
261        #[cfg(test)]
262        {
263            if let Executor::Deterministic(executor) = &self.executor {
264                executor.simulate_random_delay().await;
265            }
266
267            self.runtime.as_ref().unwrap().block_on(future)
268        }
269
270        #[cfg(not(test))]
271        {
272            future.await
273        }
274    }
275
276    async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
277        // If the error is due to a failure to serialize concurrent transactions, then retry
278        // this transaction after a delay. With each subsequent retry, double the delay duration.
279        // Also vary the delay randomly in order to ensure different database connections retry
280        // at different times.
281        if is_serialization_error(error) {
282            let base_delay = 4_u64 << prev_attempt_count.min(16);
283            let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
284            log::info!(
285                "retrying transaction after serialization error. delay: {} ms.",
286                randomized_delay
287            );
288            self.executor
289                .sleep(Duration::from_millis(randomized_delay as u64))
290                .await;
291            true
292        } else {
293            false
294        }
295    }
296}
297
298fn is_serialization_error(error: &Error) -> bool {
299    const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
300    match error {
301        Error::Database(
302            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
303            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
304        ) if error
305            .as_database_error()
306            .and_then(|error| error.code())
307            .as_deref()
308            == Some(SERIALIZATION_FAILURE_CODE) =>
309        {
310            true
311        }
312        _ => false,
313    }
314}
315
316struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
317
318impl Deref for TransactionHandle {
319    type Target = DatabaseTransaction;
320
321    fn deref(&self) -> &Self::Target {
322        self.0.as_ref().as_ref().unwrap()
323    }
324}
325
/// The result of a room transaction, bundled with the lock on that room.
///
/// The room's mutex stays locked for as long as this value is alive. The data
/// is reachable through `Deref`/`DerefMut`, or taken out with `into_inner`,
/// which drops the lock.
pub struct RoomGuard<T> {
    data: T,
    /// Owned guard on the room's mutex from `Database::rooms`.
    _guard: OwnedMutexGuard<()>,
    /// `Rc<()>` is `!Send`, so a `RoomGuard` cannot be sent to another thread
    /// (nor held across an `.await` inside a future that must be `Send`).
    _not_send: PhantomData<Rc<()>>,
}
331
332impl<T> Deref for RoomGuard<T> {
333    type Target = T;
334
335    fn deref(&self) -> &T {
336        &self.data
337    }
338}
339
340impl<T> DerefMut for RoomGuard<T> {
341    fn deref_mut(&mut self) -> &mut T {
342        &mut self.data
343    }
344}
345
346impl<T> RoomGuard<T> {
347    pub fn into_inner(self) -> T {
348        self.data
349    }
350}
351
/// A contact relationship involving another user, identified by `user_id`.
///
/// NOTE(review): variant meanings are inferred from their names. `Accepted`
/// looks like an established contact; `Outgoing`/`Incoming` look like pending
/// requests sent/received. Confirm against the contact queries.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    Accepted {
        user_id: UserId,
        should_notify: bool,
        busy: bool,
    },
    Outgoing {
        user_id: UserId,
    },
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
367
368impl Contact {
369    pub fn user_id(&self) -> UserId {
370        match self {
371            Contact::Accepted { user_id, .. } => *user_id,
372            Contact::Outgoing { user_id } => *user_id,
373            Contact::Incoming { user_id, .. } => *user_id,
374        }
375    }
376}
377
/// An email address paired with its confirmation code.
///
/// Derives `FromQueryResult`, so it can be read directly from a query row.
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}

/// Details of a new signup, deserializable via serde.
#[derive(Clone, Debug, Deserialize)]
pub struct NewSignup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
    // Optional creation time. NOTE(review): confirm what the signup query uses
    // when this is `None`.
    pub created_at: Option<DateTime>,
}

/// Summary counts for the waitlist, read from a query row via `FromQueryResult`.
///
/// NOTE(review): presumably `count` is the total and the other fields are a
/// per-platform breakdown. Confirm against the query that produces it.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
pub struct WaitlistSummary {
    pub count: i64,
    pub linux_count: i64,
    pub mac_count: i64,
    pub windows_count: i64,
    pub unknown_count: i64,
}
405
/// Parameters for creating a user, identified by GitHub login and user id.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}

/// Outcome of creating a new user.
///
/// NOTE(review): `inviting_user_id` / `signup_device_id` are optional;
/// presumably `None` when the user did not come from an invite or signup.
/// Confirm in the user-creation query.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
420
/// A channel as read from a query row: its id, name, and optional parent id.
#[derive(FromQueryResult, Debug, PartialEq)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    pub parent_id: Option<ChannelId>,
}

/// Channel information gathered for a single user.
#[derive(Debug, PartialEq)]
pub struct ChannelsForUser {
    pub channels: Vec<Channel>,
    // Per channel id, the ids of the users participating in it.
    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
    pub channels_with_admin_privileges: HashSet<ChannelId>,
}
434
/// Result of joining a room: the room's proto state plus channel details.
///
/// NOTE(review): `channel_id` is optional; presumably it is `Some` only for
/// rooms that belong to a channel, with `channel_members` listing that
/// channel's users. Confirm in the join-room query.
#[derive(Clone)]
pub struct JoinRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}

/// Result of rejoining a room: the room's proto state, the projects rejoined
/// and reshared by the caller, and the same channel details as [`JoinRoom`].
pub struct RejoinedRoom {
    pub room: proto::Room,
    pub rejoined_projects: Vec<RejoinedProject>,
    pub reshared_projects: Vec<ResharedProject>,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}
449
/// A project reshared during a room rejoin.
///
/// `old_connection_id` is the connection id the project was previously
/// associated with (NOTE(review): presumably the host's pre-reconnect
/// connection; confirm).
pub struct ResharedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<proto::WorktreeMetadata>,
}

/// A project rejoined during a room rejoin, with its collaborators, worktree
/// changes, and language servers.
pub struct RejoinedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<RejoinedWorktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}

/// Worktree state sent back on rejoin, expressed as updated and removed
/// entries/repositories.
///
/// NOTE(review): presumably these are deltas since the rejoining client's
/// last-known scan; confirm how `scan_id` / `completed_scan_id` are used by
/// the rejoin query.
#[derive(Debug)]
pub struct RejoinedWorktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub updated_entries: Vec<proto::Entry>,
    pub removed_entries: Vec<u64>,
    pub updated_repositories: Vec<proto::RepositoryEntry>,
    pub removed_repositories: Vec<u64>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
480
/// Result of leaving a room.
///
/// NOTE(review): field meanings are inferred from names. `left_projects` is
/// keyed by project id; `canceled_calls_to_user_ids` looks like the users whose
/// pending calls were canceled; `deleted` looks like "the room itself was
/// removed". Confirm in the leave-room query.
pub struct LeftRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
    pub deleted: bool,
}

/// Result of refreshing a room's participants.
///
/// NOTE(review): `stale_participant_user_ids` presumably lists participants
/// dropped as stale. Confirm in the refresh query.
pub struct RefreshedRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub stale_participant_user_ids: Vec<UserId>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
497
/// A shared project: its collaborators, worktrees, and language servers.
///
/// `worktrees` is keyed by a `u64` (NOTE(review): presumably the worktree id,
/// matching `Worktree::id`; confirm).
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}

/// A participant in a project: their connection, user, replica id, and whether
/// they are the host.
pub struct ProjectCollaborator {
    pub connection_id: ConnectionId,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    pub is_host: bool,
}
510
511impl ProjectCollaborator {
512    pub fn to_proto(&self) -> proto::Collaborator {
513        proto::Collaborator {
514            peer_id: Some(self.connection_id.into()),
515            replica_id: self.replica_id.0 as u32,
516            user_id: self.user_id.to_proto(),
517        }
518    }
519}
520
/// A project affected by a user leaving a room, with its host and the
/// connections associated with it.
///
/// NOTE(review): presumably `connection_ids` are the remaining collaborators to
/// notify. Confirm in the leave-room query.
#[derive(Debug)]
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    pub connection_ids: Vec<ConnectionId>,
}
528
/// Full stored state of a project worktree.
///
/// Holds its entries, its repositories (keyed by a `u64` id), diagnostic
/// summaries, settings files, and scan counters.
pub struct Worktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}

/// A settings file inside a worktree: its path and raw contents.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    pub path: String,
    pub content: String,
}