db.rs

  1#[cfg(test)]
  2pub mod tests;
  3
  4#[cfg(test)]
  5pub use tests::TestDb;
  6
  7mod ids;
  8mod queries;
  9mod tables;
 10
 11use crate::{executor::Executor, Error, Result};
 12use anyhow::anyhow;
 13use collections::{BTreeMap, HashMap, HashSet};
 14use dashmap::DashMap;
 15use futures::StreamExt;
 16use rand::{prelude::StdRng, Rng, SeedableRng};
 17use rpc::{
 18    proto::{self},
 19    ConnectionId,
 20};
 21use sea_orm::{
 22    entity::prelude::*, ActiveValue, Condition, ConnectionTrait, DatabaseConnection,
 23    DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType,
 24    QueryOrder, QuerySelect, Statement, TransactionTrait,
 25};
 26use sea_query::{Alias, Expr, OnConflict, Query};
 27use serde::{Deserialize, Serialize};
 28use sqlx::{
 29    migrate::{Migrate, Migration, MigrationSource},
 30    Connection,
 31};
 32use std::{
 33    fmt::Write as _,
 34    future::Future,
 35    marker::PhantomData,
 36    ops::{Deref, DerefMut},
 37    path::Path,
 38    rc::Rc,
 39    sync::Arc,
 40    time::Duration,
 41};
 42use tables::*;
 43use tokio::sync::{Mutex, OwnedMutexGuard};
 44
 45pub use ids::*;
 46pub use sea_orm::ConnectOptions;
 47pub use tables::user::Model as User;
 48
 49use self::queries::channels::ChannelGraph;
 50
 51pub struct Database {
 52    options: ConnectOptions,
 53    pool: DatabaseConnection,
 54    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
 55    rng: Mutex<StdRng>,
 56    executor: Executor,
 57    #[cfg(test)]
 58    runtime: Option<tokio::runtime::Runtime>,
 59}
 60
 61// The `Database` type has so many methods that its impl blocks are split into
 62// separate files in the `queries` folder.
 63impl Database {
 64    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
 65        Ok(Self {
 66            options: options.clone(),
 67            pool: sea_orm::Database::connect(options).await?,
 68            rooms: DashMap::with_capacity(16384),
 69            rng: Mutex::new(StdRng::seed_from_u64(0)),
 70            executor,
 71            #[cfg(test)]
 72            runtime: None,
 73        })
 74    }
 75
 76    #[cfg(test)]
 77    pub fn reset(&self) {
 78        self.rooms.clear();
 79    }
 80
 81    pub async fn migrate(
 82        &self,
 83        migrations_path: &Path,
 84        ignore_checksum_mismatch: bool,
 85    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 86        let migrations = MigrationSource::resolve(migrations_path)
 87            .await
 88            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 89
 90        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
 91
 92        connection.ensure_migrations_table().await?;
 93        let applied_migrations: HashMap<_, _> = connection
 94            .list_applied_migrations()
 95            .await?
 96            .into_iter()
 97            .map(|m| (m.version, m))
 98            .collect();
 99
100        let mut new_migrations = Vec::new();
101        for migration in migrations {
102            match applied_migrations.get(&migration.version) {
103                Some(applied_migration) => {
104                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
105                    {
106                        Err(anyhow!(
107                            "checksum mismatch for applied migration {}",
108                            migration.description
109                        ))?;
110                    }
111                }
112                None => {
113                    let elapsed = connection.apply(&migration).await?;
114                    new_migrations.push((migration, elapsed));
115                }
116            }
117        }
118
119        Ok(new_migrations)
120    }
121
122    async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
123    where
124        F: Send + Fn(TransactionHandle) -> Fut,
125        Fut: Send + Future<Output = Result<T>>,
126    {
127        let body = async {
128            let mut i = 0;
129            loop {
130                let (tx, result) = self.with_transaction(&f).await?;
131                match result {
132                    Ok(result) => match tx.commit().await.map_err(Into::into) {
133                        Ok(()) => return Ok(result),
134                        Err(error) => {
135                            if !self.retry_on_serialization_error(&error, i).await {
136                                return Err(error);
137                            }
138                        }
139                    },
140                    Err(error) => {
141                        tx.rollback().await?;
142                        if !self.retry_on_serialization_error(&error, i).await {
143                            return Err(error);
144                        }
145                    }
146                }
147                i += 1;
148            }
149        };
150
151        self.run(body).await
152    }
153
154    async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
155    where
156        F: Send + Fn(TransactionHandle) -> Fut,
157        Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
158    {
159        let body = async {
160            let mut i = 0;
161            loop {
162                let (tx, result) = self.with_transaction(&f).await?;
163                match result {
164                    Ok(Some((room_id, data))) => {
165                        let lock = self.rooms.entry(room_id).or_default().clone();
166                        let _guard = lock.lock_owned().await;
167                        match tx.commit().await.map_err(Into::into) {
168                            Ok(()) => {
169                                return Ok(Some(RoomGuard {
170                                    data,
171                                    _guard,
172                                    _not_send: PhantomData,
173                                }));
174                            }
175                            Err(error) => {
176                                if !self.retry_on_serialization_error(&error, i).await {
177                                    return Err(error);
178                                }
179                            }
180                        }
181                    }
182                    Ok(None) => match tx.commit().await.map_err(Into::into) {
183                        Ok(()) => return Ok(None),
184                        Err(error) => {
185                            if !self.retry_on_serialization_error(&error, i).await {
186                                return Err(error);
187                            }
188                        }
189                    },
190                    Err(error) => {
191                        tx.rollback().await?;
192                        if !self.retry_on_serialization_error(&error, i).await {
193                            return Err(error);
194                        }
195                    }
196                }
197                i += 1;
198            }
199        };
200
201        self.run(body).await
202    }
203
204    async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
205    where
206        F: Send + Fn(TransactionHandle) -> Fut,
207        Fut: Send + Future<Output = Result<T>>,
208    {
209        let body = async {
210            let mut i = 0;
211            loop {
212                let lock = self.rooms.entry(room_id).or_default().clone();
213                let _guard = lock.lock_owned().await;
214                let (tx, result) = self.with_transaction(&f).await?;
215                match result {
216                    Ok(data) => match tx.commit().await.map_err(Into::into) {
217                        Ok(()) => {
218                            return Ok(RoomGuard {
219                                data,
220                                _guard,
221                                _not_send: PhantomData,
222                            });
223                        }
224                        Err(error) => {
225                            if !self.retry_on_serialization_error(&error, i).await {
226                                return Err(error);
227                            }
228                        }
229                    },
230                    Err(error) => {
231                        tx.rollback().await?;
232                        if !self.retry_on_serialization_error(&error, i).await {
233                            return Err(error);
234                        }
235                    }
236                }
237                i += 1;
238            }
239        };
240
241        self.run(body).await
242    }
243
244    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
245    where
246        F: Send + Fn(TransactionHandle) -> Fut,
247        Fut: Send + Future<Output = Result<T>>,
248    {
249        let tx = self
250            .pool
251            .begin_with_config(Some(IsolationLevel::Serializable), None)
252            .await?;
253
254        let mut tx = Arc::new(Some(tx));
255        let result = f(TransactionHandle(tx.clone())).await;
256        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
257            return Err(anyhow!(
258                "couldn't complete transaction because it's still in use"
259            ))?;
260        };
261
262        Ok((tx, result))
263    }
264
265    async fn run<F, T>(&self, future: F) -> Result<T>
266    where
267        F: Future<Output = Result<T>>,
268    {
269        #[cfg(test)]
270        {
271            if let Executor::Deterministic(executor) = &self.executor {
272                executor.simulate_random_delay().await;
273            }
274
275            self.runtime.as_ref().unwrap().block_on(future)
276        }
277
278        #[cfg(not(test))]
279        {
280            future.await
281        }
282    }
283
284    async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
285        // If the error is due to a failure to serialize concurrent transactions, then retry
286        // this transaction after a delay. With each subsequent retry, double the delay duration.
287        // Also vary the delay randomly in order to ensure different database connections retry
288        // at different times.
289        if is_serialization_error(error) {
290            let base_delay = 4_u64 << prev_attempt_count.min(16);
291            let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
292            log::info!(
293                "retrying transaction after serialization error. delay: {} ms.",
294                randomized_delay
295            );
296            self.executor
297                .sleep(Duration::from_millis(randomized_delay as u64))
298                .await;
299            true
300        } else {
301            false
302        }
303    }
304}
305
306fn is_serialization_error(error: &Error) -> bool {
307    const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
308    match error {
309        Error::Database(
310            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
311            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
312        ) if error
313            .as_database_error()
314            .and_then(|error| error.code())
315            .as_deref()
316            == Some(SERIALIZATION_FAILURE_CODE) =>
317        {
318            true
319        }
320        _ => false,
321    }
322}
323
324struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
325
326impl Deref for TransactionHandle {
327    type Target = DatabaseTransaction;
328
329    fn deref(&self) -> &Self::Target {
330        self.0.as_ref().as_ref().unwrap()
331    }
332}
333
334pub struct RoomGuard<T> {
335    data: T,
336    _guard: OwnedMutexGuard<()>,
337    _not_send: PhantomData<Rc<()>>,
338}
339
340impl<T> Deref for RoomGuard<T> {
341    type Target = T;
342
343    fn deref(&self) -> &T {
344        &self.data
345    }
346}
347
348impl<T> DerefMut for RoomGuard<T> {
349    fn deref_mut(&mut self) -> &mut T {
350        &mut self.data
351    }
352}
353
354impl<T> RoomGuard<T> {
355    pub fn into_inner(self) -> T {
356        self.data
357    }
358}
359
360#[derive(Clone, Debug, PartialEq, Eq)]
361pub enum Contact {
362    Accepted {
363        user_id: UserId,
364        should_notify: bool,
365        busy: bool,
366    },
367    Outgoing {
368        user_id: UserId,
369    },
370    Incoming {
371        user_id: UserId,
372        should_notify: bool,
373    },
374}
375
376impl Contact {
377    pub fn user_id(&self) -> UserId {
378        match self {
379            Contact::Accepted { user_id, .. } => *user_id,
380            Contact::Outgoing { user_id } => *user_id,
381            Contact::Incoming { user_id, .. } => *user_id,
382        }
383    }
384}
385
386#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
387pub struct Invite {
388    pub email_address: String,
389    pub email_confirmation_code: String,
390}
391
392#[derive(Clone, Debug, Deserialize)]
393pub struct NewSignup {
394    pub email_address: String,
395    pub platform_mac: bool,
396    pub platform_windows: bool,
397    pub platform_linux: bool,
398    pub editor_features: Vec<String>,
399    pub programming_languages: Vec<String>,
400    pub device_id: Option<String>,
401    pub added_to_mailing_list: bool,
402    pub created_at: Option<DateTime>,
403}
404
405#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
406pub struct WaitlistSummary {
407    pub count: i64,
408    pub linux_count: i64,
409    pub mac_count: i64,
410    pub windows_count: i64,
411    pub unknown_count: i64,
412}
413
414#[derive(Debug, Serialize, Deserialize)]
415pub struct NewUserParams {
416    pub github_login: String,
417    pub github_user_id: i32,
418    pub invite_count: i32,
419}
420
421#[derive(Debug)]
422pub struct NewUserResult {
423    pub user_id: UserId,
424    pub metrics_id: String,
425    pub inviting_user_id: Option<UserId>,
426    pub signup_device_id: Option<String>,
427}
428
429#[derive(FromQueryResult, Debug, PartialEq, Eq, Hash)]
430pub struct Channel {
431    pub id: ChannelId,
432    pub name: String,
433}
434
435#[derive(Debug, PartialEq)]
436pub struct ChannelsForUser {
437    pub channels: ChannelGraph,
438    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
439    pub channels_with_admin_privileges: HashSet<ChannelId>,
440}
441
442#[derive(Debug)]
443pub struct RejoinedChannelBuffer {
444    pub buffer: proto::RejoinedChannelBuffer,
445    pub old_connection_id: ConnectionId,
446}
447
448#[derive(Clone)]
449pub struct JoinRoom {
450    pub room: proto::Room,
451    pub channel_id: Option<ChannelId>,
452    pub channel_members: Vec<UserId>,
453}
454
455pub struct RejoinedRoom {
456    pub room: proto::Room,
457    pub rejoined_projects: Vec<RejoinedProject>,
458    pub reshared_projects: Vec<ResharedProject>,
459    pub channel_id: Option<ChannelId>,
460    pub channel_members: Vec<UserId>,
461}
462
463pub struct ResharedProject {
464    pub id: ProjectId,
465    pub old_connection_id: ConnectionId,
466    pub collaborators: Vec<ProjectCollaborator>,
467    pub worktrees: Vec<proto::WorktreeMetadata>,
468}
469
470pub struct RejoinedProject {
471    pub id: ProjectId,
472    pub old_connection_id: ConnectionId,
473    pub collaborators: Vec<ProjectCollaborator>,
474    pub worktrees: Vec<RejoinedWorktree>,
475    pub language_servers: Vec<proto::LanguageServer>,
476}
477
478#[derive(Debug)]
479pub struct RejoinedWorktree {
480    pub id: u64,
481    pub abs_path: String,
482    pub root_name: String,
483    pub visible: bool,
484    pub updated_entries: Vec<proto::Entry>,
485    pub removed_entries: Vec<u64>,
486    pub updated_repositories: Vec<proto::RepositoryEntry>,
487    pub removed_repositories: Vec<u64>,
488    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
489    pub settings_files: Vec<WorktreeSettingsFile>,
490    pub scan_id: u64,
491    pub completed_scan_id: u64,
492}
493
494pub struct LeftRoom {
495    pub room: proto::Room,
496    pub channel_id: Option<ChannelId>,
497    pub channel_members: Vec<UserId>,
498    pub left_projects: HashMap<ProjectId, LeftProject>,
499    pub canceled_calls_to_user_ids: Vec<UserId>,
500    pub deleted: bool,
501}
502
503pub struct RefreshedRoom {
504    pub room: proto::Room,
505    pub channel_id: Option<ChannelId>,
506    pub channel_members: Vec<UserId>,
507    pub stale_participant_user_ids: Vec<UserId>,
508    pub canceled_calls_to_user_ids: Vec<UserId>,
509}
510
511pub struct RefreshedChannelBuffer {
512    pub connection_ids: Vec<ConnectionId>,
513    pub collaborators: Vec<proto::Collaborator>,
514}
515
516pub struct Project {
517    pub collaborators: Vec<ProjectCollaborator>,
518    pub worktrees: BTreeMap<u64, Worktree>,
519    pub language_servers: Vec<proto::LanguageServer>,
520}
521
522pub struct ProjectCollaborator {
523    pub connection_id: ConnectionId,
524    pub user_id: UserId,
525    pub replica_id: ReplicaId,
526    pub is_host: bool,
527}
528
529impl ProjectCollaborator {
530    pub fn to_proto(&self) -> proto::Collaborator {
531        proto::Collaborator {
532            peer_id: Some(self.connection_id.into()),
533            replica_id: self.replica_id.0 as u32,
534            user_id: self.user_id.to_proto(),
535        }
536    }
537}
538
539#[derive(Debug)]
540pub struct LeftProject {
541    pub id: ProjectId,
542    pub host_user_id: UserId,
543    pub host_connection_id: ConnectionId,
544    pub connection_ids: Vec<ConnectionId>,
545}
546
547pub struct Worktree {
548    pub id: u64,
549    pub abs_path: String,
550    pub root_name: String,
551    pub visible: bool,
552    pub entries: Vec<proto::Entry>,
553    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
554    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
555    pub settings_files: Vec<WorktreeSettingsFile>,
556    pub scan_id: u64,
557    pub completed_scan_id: u64,
558}
559
/// A settings file attached to a worktree, as a path plus its full text.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    /// Path of the settings file.
    pub path: String,
    /// Contents of the settings file.
    pub content: String,
}