1#[cfg(test)]
2mod db_tests;
3#[cfg(test)]
4pub mod test_db;
5
6mod ids;
7mod queries;
8mod tables;
9
10use crate::{executor::Executor, Error, Result};
11use anyhow::anyhow;
12use collections::{BTreeMap, HashMap, HashSet};
13use dashmap::DashMap;
14use futures::StreamExt;
15use rand::{prelude::StdRng, Rng, SeedableRng};
16use rpc::{proto, ConnectionId};
17use sea_orm::{
18 entity::prelude::*, ActiveValue, Condition, ConnectionTrait, DatabaseConnection,
19 DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType,
20 QueryOrder, QuerySelect, Statement, TransactionTrait,
21};
22use sea_query::{Alias, Expr, OnConflict, Query};
23use serde::{Deserialize, Serialize};
24use sqlx::{
25 migrate::{Migrate, Migration, MigrationSource},
26 Connection,
27};
28use std::{
29 fmt::Write as _,
30 future::Future,
31 marker::PhantomData,
32 ops::{Deref, DerefMut},
33 path::Path,
34 rc::Rc,
35 sync::Arc,
36 time::Duration,
37};
38use tables::*;
39use tokio::sync::{Mutex, OwnedMutexGuard};
40
41pub use ids::*;
42pub use sea_orm::ConnectOptions;
43pub use tables::user::Model as User;
44
/// Handle to the application's database.
///
/// Bundles the connection pool with the state that the transaction helpers
/// in this file rely on: per-room mutexes, a random number generator for
/// retry jitter, and the executor used to sleep between retries.
pub struct Database {
    /// Options the pool was created with; `migrate` reads the URL from here
    /// to open its own connection.
    options: ConnectOptions,
    /// Connection that every transaction is started on.
    pool: DatabaseConnection,
    /// One async mutex per room, inserted lazily by the room transaction
    /// helpers the first time a room id is seen.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// Random number generator used to randomize retry delays.
    rng: Mutex<StdRng>,
    /// Executor used to sleep between retries and, in test builds, to
    /// simulate random delays.
    executor: Executor,
    /// Tokio runtime that `run` blocks on in test builds.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
54
55// The `Database` type has so many methods that its impl blocks are split into
56// separate files in the `queries` folder.
57impl Database {
58 pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
59 Ok(Self {
60 options: options.clone(),
61 pool: sea_orm::Database::connect(options).await?,
62 rooms: DashMap::with_capacity(16384),
63 rng: Mutex::new(StdRng::seed_from_u64(0)),
64 executor,
65 #[cfg(test)]
66 runtime: None,
67 })
68 }
69
/// Removes every per-room lock entry. Only compiled in test builds.
#[cfg(test)]
pub fn reset(&self) {
    let Self { rooms, .. } = self;
    rooms.clear();
}
74
75 pub async fn migrate(
76 &self,
77 migrations_path: &Path,
78 ignore_checksum_mismatch: bool,
79 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
80 let migrations = MigrationSource::resolve(migrations_path)
81 .await
82 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
83
84 let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
85
86 connection.ensure_migrations_table().await?;
87 let applied_migrations: HashMap<_, _> = connection
88 .list_applied_migrations()
89 .await?
90 .into_iter()
91 .map(|m| (m.version, m))
92 .collect();
93
94 let mut new_migrations = Vec::new();
95 for migration in migrations {
96 match applied_migrations.get(&migration.version) {
97 Some(applied_migration) => {
98 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
99 {
100 Err(anyhow!(
101 "checksum mismatch for applied migration {}",
102 migration.description
103 ))?;
104 }
105 }
106 None => {
107 let elapsed = connection.apply(&migration).await?;
108 new_migrations.push((migration, elapsed));
109 }
110 }
111 }
112
113 Ok(new_migrations)
114 }
115
116 async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
117 where
118 F: Send + Fn(TransactionHandle) -> Fut,
119 Fut: Send + Future<Output = Result<T>>,
120 {
121 let body = async {
122 let mut i = 0;
123 loop {
124 let (tx, result) = self.with_transaction(&f).await?;
125 match result {
126 Ok(result) => match tx.commit().await.map_err(Into::into) {
127 Ok(()) => return Ok(result),
128 Err(error) => {
129 if !self.retry_on_serialization_error(&error, i).await {
130 return Err(error);
131 }
132 }
133 },
134 Err(error) => {
135 tx.rollback().await?;
136 if !self.retry_on_serialization_error(&error, i).await {
137 return Err(error);
138 }
139 }
140 }
141 i += 1;
142 }
143 };
144
145 self.run(body).await
146 }
147
/// Runs `f` inside a transaction that may or may not turn out to concern a room.
///
/// * If `f` yields `Some((room_id, data))`, the mutex for `room_id` is
///   acquired *before* the transaction is committed, and the returned
///   [`RoomGuard`] keeps holding that mutex alongside `data`.
/// * If `f` yields `None`, the transaction is committed and `Ok(None)` is
///   returned without touching any room mutex.
/// * If `f` fails, the transaction is rolled back.
///
/// A failure from `f` or from a commit is retried when
/// `retry_on_serialization_error` returns `true`; otherwise it is returned.
/// `f` may therefore be invoked more than once.
async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
where
    F: Send + Fn(TransactionHandle) -> Fut,
    Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
{
    let body = async {
        // Number of attempts made so far; drives the retry backoff.
        let mut i = 0;
        loop {
            let (tx, result) = self.with_transaction(&f).await?;
            match result {
                Ok(Some((room_id, data))) => {
                    // Fetch (or lazily create) this room's mutex and clone the
                    // `Arc` so the lock can be awaited on an owned handle.
                    let lock = self.rooms.entry(room_id).or_default().clone();
                    // Take the room lock before committing. If the commit
                    // fails, `_guard` is dropped at the end of this arm, i.e.
                    // after the retry decision (and its sleep) below.
                    let _guard = lock.lock_owned().await;
                    match tx.commit().await.map_err(Into::into) {
                        Ok(()) => {
                            return Ok(Some(RoomGuard {
                                data,
                                _guard,
                                _not_send: PhantomData,
                            }));
                        }
                        Err(error) => {
                            if !self.retry_on_serialization_error(&error, i).await {
                                return Err(error);
                            }
                        }
                    }
                }
                // No room involved: commit without taking any lock.
                Ok(None) => match tx.commit().await.map_err(Into::into) {
                    Ok(()) => return Ok(None),
                    Err(error) => {
                        if !self.retry_on_serialization_error(&error, i).await {
                            return Err(error);
                        }
                    }
                },
                Err(error) => {
                    // A rollback failure is propagated in place of `error`.
                    tx.rollback().await?;
                    if !self.retry_on_serialization_error(&error, i).await {
                        return Err(error);
                    }
                }
            }
            i += 1;
        }
    };

    self.run(body).await
}
197
/// Runs `f` inside a transaction while holding the mutex for `room_id`.
///
/// On every attempt the room mutex is acquired first, then the transaction
/// is started and `f` is run. On success the transaction is committed and
/// the returned [`RoomGuard`] keeps holding the mutex alongside the value
/// produced by `f`. If `f` fails, the transaction is rolled back.
///
/// A failure from `f` or from the commit is retried when
/// `retry_on_serialization_error` returns `true`; otherwise it is returned.
/// `f` may therefore be invoked more than once.
async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
where
    F: Send + Fn(TransactionHandle) -> Fut,
    Fut: Send + Future<Output = Result<T>>,
{
    let body = async {
        // Number of attempts made so far; drives the retry backoff.
        let mut i = 0;
        loop {
            // Fetch (or lazily create) this room's mutex and clone the `Arc`
            // so the lock can be awaited on an owned handle.
            let lock = self.rooms.entry(room_id).or_default().clone();
            // `_guard` lives until the end of this loop iteration, so on a
            // retry the lock stays held through the retry sleep and is then
            // re-acquired at the top of the next iteration.
            let _guard = lock.lock_owned().await;
            let (tx, result) = self.with_transaction(&f).await?;
            match result {
                Ok(data) => match tx.commit().await.map_err(Into::into) {
                    Ok(()) => {
                        return Ok(RoomGuard {
                            data,
                            _guard,
                            _not_send: PhantomData,
                        });
                    }
                    Err(error) => {
                        if !self.retry_on_serialization_error(&error, i).await {
                            return Err(error);
                        }
                    }
                },
                Err(error) => {
                    // A rollback failure is propagated in place of `error`.
                    tx.rollback().await?;
                    if !self.retry_on_serialization_error(&error, i).await {
                        return Err(error);
                    }
                }
            }
            i += 1;
        }
    };

    self.run(body).await
}
237
238 async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
239 where
240 F: Send + Fn(TransactionHandle) -> Fut,
241 Fut: Send + Future<Output = Result<T>>,
242 {
243 let tx = self
244 .pool
245 .begin_with_config(Some(IsolationLevel::Serializable), None)
246 .await?;
247
248 let mut tx = Arc::new(Some(tx));
249 let result = f(TransactionHandle(tx.clone())).await;
250 let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
251 return Err(anyhow!("couldn't complete transaction because it's still in use"))?;
252 };
253
254 Ok((tx, result))
255 }
256
/// Drives `future` to completion and returns its result.
///
/// In non-test builds this simply awaits `future`. In test builds it first
/// awaits a simulated random delay when the executor is
/// `Executor::Deterministic`, then blocks on `future` using the runtime
/// stored in `self.runtime`.
///
/// # Panics
/// In test builds, panics if `self.runtime` is `None`.
async fn run<F, T>(&self, future: F) -> Result<T>
where
    F: Future<Output = Result<T>>,
{
    #[cfg(test)]
    {
        if let Executor::Deterministic(executor) = &self.executor {
            executor.simulate_random_delay().await;
        }

        // Blocks the current thread on the stored Tokio runtime.
        self.runtime.as_ref().unwrap().block_on(future)
    }

    #[cfg(not(test))]
    {
        future.await
    }
}
275
276 async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
277 // If the error is due to a failure to serialize concurrent transactions, then retry
278 // this transaction after a delay. With each subsequent retry, double the delay duration.
279 // Also vary the delay randomly in order to ensure different database connections retry
280 // at different times.
281 if is_serialization_error(error) {
282 let base_delay = 4_u64 << prev_attempt_count.min(16);
283 let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
284 log::info!(
285 "retrying transaction after serialization error. delay: {} ms.",
286 randomized_delay
287 );
288 self.executor
289 .sleep(Duration::from_millis(randomized_delay as u64))
290 .await;
291 true
292 } else {
293 false
294 }
295 }
296}
297
298fn is_serialization_error(error: &Error) -> bool {
299 const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
300 match error {
301 Error::Database(
302 DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
303 | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
304 ) if error
305 .as_database_error()
306 .and_then(|error| error.code())
307 .as_deref()
308 == Some(SERIALIZATION_FAILURE_CODE) =>
309 {
310 true
311 }
312 _ => false,
313 }
314}
315
/// Shared handle through which a transaction callback reaches the underlying
/// `DatabaseTransaction`. The inner `Option` lets `with_transaction` take the
/// transaction back out once the callback has finished.
struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
317
318impl Deref for TransactionHandle {
319 type Target = DatabaseTransaction;
320
321 fn deref(&self) -> &Self::Target {
322 self.0.as_ref().as_ref().unwrap()
323 }
324}
325
/// A value of type `T` bundled with a held room mutex guard.
///
/// The mutex stays locked for as long as the `RoomGuard` is alive.
pub struct RoomGuard<T> {
    /// The wrapped value, reachable through `Deref`/`DerefMut`.
    data: T,
    /// Owned guard for the room's mutex; unlocks when dropped.
    _guard: OwnedMutexGuard<()>,
    /// `Rc` is not `Send`, so this marker makes `RoomGuard` not `Send` either.
    _not_send: PhantomData<Rc<()>>,
}
331
332impl<T> Deref for RoomGuard<T> {
333 type Target = T;
334
335 fn deref(&self) -> &T {
336 &self.data
337 }
338}
339
340impl<T> DerefMut for RoomGuard<T> {
341 fn deref_mut(&mut self) -> &mut T {
342 &mut self.data
343 }
344}
345
346impl<T> RoomGuard<T> {
347 pub fn into_inner(self) -> T {
348 self.data
349 }
350}
351
/// A contact relationship with another user, in one of three states.
///
/// Every variant carries the other user's id; see `Contact::user_id`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The contact has been accepted.
    Accepted {
        user_id: UserId,
        should_notify: bool,
        busy: bool,
    },
    /// Presumably a request sent to `user_id` that is still pending; confirm
    /// against the queries that build this value.
    Outgoing {
        user_id: UserId,
    },
    /// Presumably a request received from `user_id` that is still pending;
    /// confirm against the queries that build this value.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
367
368impl Contact {
369 pub fn user_id(&self) -> UserId {
370 match self {
371 Contact::Accepted { user_id, .. } => *user_id,
372 Contact::Outgoing { user_id } => *user_id,
373 Contact::Incoming { user_id, .. } => *user_id,
374 }
375 }
376}
377
/// An invite, made up of an email address and its confirmation code.
///
/// Derives `FromQueryResult`, so it can be built from a query result, as well
/// as the serde traits.
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
383
/// Deserializable description of a new signup: the email address, one flag
/// per platform, two lists of free-form strings, and optional metadata.
#[derive(Clone, Debug, Deserialize)]
pub struct NewSignup {
    pub email_address: String,
    // One flag per platform; nothing visible here prevents several being set.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
    // NOTE(review): presumably falls back to "now" when `None`; confirm in
    // the query that inserts signups.
    pub created_at: Option<DateTime>,
}
396
/// Counts summarizing the waitlist, built from a query result.
///
/// NOTE(review): `count` is presumably the overall total and the other
/// fields a per-platform breakdown; confirm against the query that fills
/// this in.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
pub struct WaitlistSummary {
    pub count: i64,
    pub linux_count: i64,
    pub mac_count: i64,
    pub windows_count: i64,
    pub unknown_count: i64,
}
405
/// Parameters describing a user to create: GitHub login, GitHub user id, and
/// an invite count.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
412
/// Identifiers describing a newly created user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// The user who invited this one, if any.
    pub inviting_user_id: Option<UserId>,
    /// Device id recorded at signup, if any.
    pub signup_device_id: Option<String>,
}
420
/// A channel as loaded from a query result: its id, name, and optional
/// parent channel.
#[derive(FromQueryResult, Debug, PartialEq)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Parent channel, if any; `None` presumably marks a top-level channel.
    pub parent_id: Option<ChannelId>,
}
427
/// A set of channels together with per-channel participants and the subset
/// of channels flagged as having admin privileges.
#[derive(Debug, PartialEq)]
pub struct ChannelsForUser {
    pub channels: Vec<Channel>,
    /// User ids listed per channel id.
    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
    pub channels_with_admin_privileges: HashSet<ChannelId>,
}
434
/// A room in protobuf form plus its optional channel and that channel's
/// members; by its name, the data produced when a room is joined.
#[derive(Clone)]
pub struct JoinRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}
441
/// A room in protobuf form plus the projects that were rejoined and
/// reshared, and the room's optional channel with its members.
pub struct RejoinedRoom {
    pub room: proto::Room,
    pub rejoined_projects: Vec<RejoinedProject>,
    pub reshared_projects: Vec<ResharedProject>,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}
449
/// A reshared project: its id, the connection id it was previously
/// associated with, its collaborators, and its worktree metadata.
pub struct ResharedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<proto::WorktreeMetadata>,
}
456
/// A rejoined project: its id, the connection id it was previously
/// associated with, its collaborators, per-worktree state, and language
/// servers.
pub struct RejoinedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<RejoinedWorktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
464
/// Per-worktree state carried by a [`RejoinedProject`]: worktree metadata
/// plus the entries and repositories that were updated or removed,
/// diagnostic summaries, settings files, and scan ids.
#[derive(Debug)]
pub struct RejoinedWorktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub updated_entries: Vec<proto::Entry>,
    // NOTE(review): presumably the ids of the removed entries; confirm.
    pub removed_entries: Vec<u64>,
    pub updated_repositories: Vec<proto::RepositoryEntry>,
    // NOTE(review): presumably the ids of the removed repositories; confirm.
    pub removed_repositories: Vec<u64>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
480
/// A room in protobuf form plus the projects that were left, the users whose
/// calls were canceled, the room's optional channel with its members, and a
/// `deleted` flag.
pub struct LeftRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    /// Left projects keyed by project id.
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
    // NOTE(review): presumably true when the room itself was deleted; confirm.
    pub deleted: bool,
}
489
/// A room in protobuf form plus the ids of stale participants and of users
/// whose calls were canceled, and the room's optional channel with its
/// members.
pub struct RefreshedRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub stale_participant_user_ids: Vec<UserId>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
497
/// A project's collaborators, worktrees, and language servers.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    /// Worktrees in key order; the `u64` key is presumably the worktree id.
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
503
/// One collaborator on a project: connection, user, replica id, and whether
/// this collaborator is the host.
pub struct ProjectCollaborator {
    pub connection_id: ConnectionId,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    /// Not included in the protobuf form produced by `to_proto`.
    pub is_host: bool,
}
510
511impl ProjectCollaborator {
512 pub fn to_proto(&self) -> proto::Collaborator {
513 proto::Collaborator {
514 peer_id: Some(self.connection_id.into()),
515 replica_id: self.replica_id.0 as u32,
516 user_id: self.user_id.to_proto(),
517 }
518 }
519}
520
/// A project that was left: its id, its host's user and connection ids, and
/// a list of connection ids.
#[derive(Debug)]
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    // NOTE(review): presumably the connections to notify; confirm at call sites.
    pub connection_ids: Vec<ConnectionId>,
}
528
/// A worktree held by a [`Project`]: metadata, file entries, repository
/// entries, diagnostic summaries, settings files, and scan ids.
pub struct Worktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    /// Repository entries in key order; the meaning of the `u64` key is not
    /// visible in this file.
    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
541
/// A settings file attached to a worktree: its path and its content.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    pub path: String,
    pub content: String,
}