1#[cfg(test)]
2pub mod tests;
3
4#[cfg(test)]
5pub use tests::TestDb;
6
7mod ids;
8mod queries;
9mod tables;
10
11use crate::{executor::Executor, Error, Result};
12use anyhow::anyhow;
13use collections::{BTreeMap, HashMap, HashSet};
14use dashmap::DashMap;
15use futures::StreamExt;
16use rand::{prelude::StdRng, Rng, SeedableRng};
17use rpc::{proto, ConnectionId};
18use sea_orm::{
19 entity::prelude::*, ActiveValue, Condition, ConnectionTrait, DatabaseConnection,
20 DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType,
21 QueryOrder, QuerySelect, Statement, TransactionTrait,
22};
23use sea_query::{Alias, Expr, OnConflict, Query};
24use serde::{Deserialize, Serialize};
25use sqlx::{
26 migrate::{Migrate, Migration, MigrationSource},
27 Connection,
28};
29use std::{
30 fmt::Write as _,
31 future::Future,
32 marker::PhantomData,
33 ops::{Deref, DerefMut},
34 path::Path,
35 rc::Rc,
36 sync::Arc,
37 time::Duration,
38};
39use tables::*;
40use tokio::sync::{Mutex, OwnedMutexGuard};
41
42pub use ids::*;
43pub use sea_orm::ConnectOptions;
44pub use tables::user::Model as User;
45
/// Database access layer: a pooled connection plus the in-process state used by the
/// transaction helpers (per-room locks, retry-jitter RNG, and an executor).
///
/// Most query methods are implemented in separate `impl` blocks in the `queries` modules.
pub struct Database {
    /// Connection options; kept so `migrate` can open its own `sqlx` connection to the
    /// same URL.
    options: ConnectOptions,
    /// Connection pool from which every transaction is started.
    pool: DatabaseConnection,
    /// One async mutex per room, created on demand. `room_transaction` and
    /// `optional_room_transaction` lock a room's entry while working on that room.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// RNG used to randomize retry back-off delays (seeded with a fixed value in `new`).
    rng: Mutex<StdRng>,
    /// Executor used to sleep between retries (and, in tests, to inject random delays).
    executor: Executor,
    /// Test-only runtime on which `run` blocks; `new` leaves it as `None`.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
55
56// The `Database` type has so many methods that its impl blocks are split into
57// separate files in the `queries` folder.
58impl Database {
59 pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
60 Ok(Self {
61 options: options.clone(),
62 pool: sea_orm::Database::connect(options).await?,
63 rooms: DashMap::with_capacity(16384),
64 rng: Mutex::new(StdRng::seed_from_u64(0)),
65 executor,
66 #[cfg(test)]
67 runtime: None,
68 })
69 }
70
71 #[cfg(test)]
72 pub fn reset(&self) {
73 self.rooms.clear();
74 }
75
76 pub async fn migrate(
77 &self,
78 migrations_path: &Path,
79 ignore_checksum_mismatch: bool,
80 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
81 let migrations = MigrationSource::resolve(migrations_path)
82 .await
83 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
84
85 let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
86
87 connection.ensure_migrations_table().await?;
88 let applied_migrations: HashMap<_, _> = connection
89 .list_applied_migrations()
90 .await?
91 .into_iter()
92 .map(|m| (m.version, m))
93 .collect();
94
95 let mut new_migrations = Vec::new();
96 for migration in migrations {
97 match applied_migrations.get(&migration.version) {
98 Some(applied_migration) => {
99 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
100 {
101 Err(anyhow!(
102 "checksum mismatch for applied migration {}",
103 migration.description
104 ))?;
105 }
106 }
107 None => {
108 let elapsed = connection.apply(&migration).await?;
109 new_migrations.push((migration, elapsed));
110 }
111 }
112 }
113
114 Ok(new_migrations)
115 }
116
117 async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
118 where
119 F: Send + Fn(TransactionHandle) -> Fut,
120 Fut: Send + Future<Output = Result<T>>,
121 {
122 let body = async {
123 let mut i = 0;
124 loop {
125 let (tx, result) = self.with_transaction(&f).await?;
126 match result {
127 Ok(result) => match tx.commit().await.map_err(Into::into) {
128 Ok(()) => return Ok(result),
129 Err(error) => {
130 if !self.retry_on_serialization_error(&error, i).await {
131 return Err(error);
132 }
133 }
134 },
135 Err(error) => {
136 tx.rollback().await?;
137 if !self.retry_on_serialization_error(&error, i).await {
138 return Err(error);
139 }
140 }
141 }
142 i += 1;
143 }
144 };
145
146 self.run(body).await
147 }
148
149 async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
150 where
151 F: Send + Fn(TransactionHandle) -> Fut,
152 Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
153 {
154 let body = async {
155 let mut i = 0;
156 loop {
157 let (tx, result) = self.with_transaction(&f).await?;
158 match result {
159 Ok(Some((room_id, data))) => {
160 let lock = self.rooms.entry(room_id).or_default().clone();
161 let _guard = lock.lock_owned().await;
162 match tx.commit().await.map_err(Into::into) {
163 Ok(()) => {
164 return Ok(Some(RoomGuard {
165 data,
166 _guard,
167 _not_send: PhantomData,
168 }));
169 }
170 Err(error) => {
171 if !self.retry_on_serialization_error(&error, i).await {
172 return Err(error);
173 }
174 }
175 }
176 }
177 Ok(None) => match tx.commit().await.map_err(Into::into) {
178 Ok(()) => return Ok(None),
179 Err(error) => {
180 if !self.retry_on_serialization_error(&error, i).await {
181 return Err(error);
182 }
183 }
184 },
185 Err(error) => {
186 tx.rollback().await?;
187 if !self.retry_on_serialization_error(&error, i).await {
188 return Err(error);
189 }
190 }
191 }
192 i += 1;
193 }
194 };
195
196 self.run(body).await
197 }
198
199 async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
200 where
201 F: Send + Fn(TransactionHandle) -> Fut,
202 Fut: Send + Future<Output = Result<T>>,
203 {
204 let body = async {
205 let mut i = 0;
206 loop {
207 let lock = self.rooms.entry(room_id).or_default().clone();
208 let _guard = lock.lock_owned().await;
209 let (tx, result) = self.with_transaction(&f).await?;
210 match result {
211 Ok(data) => match tx.commit().await.map_err(Into::into) {
212 Ok(()) => {
213 return Ok(RoomGuard {
214 data,
215 _guard,
216 _not_send: PhantomData,
217 });
218 }
219 Err(error) => {
220 if !self.retry_on_serialization_error(&error, i).await {
221 return Err(error);
222 }
223 }
224 },
225 Err(error) => {
226 tx.rollback().await?;
227 if !self.retry_on_serialization_error(&error, i).await {
228 return Err(error);
229 }
230 }
231 }
232 i += 1;
233 }
234 };
235
236 self.run(body).await
237 }
238
239 async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
240 where
241 F: Send + Fn(TransactionHandle) -> Fut,
242 Fut: Send + Future<Output = Result<T>>,
243 {
244 let tx = self
245 .pool
246 .begin_with_config(Some(IsolationLevel::Serializable), None)
247 .await?;
248
249 let mut tx = Arc::new(Some(tx));
250 let result = f(TransactionHandle(tx.clone())).await;
251 let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
252 return Err(anyhow!(
253 "couldn't complete transaction because it's still in use"
254 ))?;
255 };
256
257 Ok((tx, result))
258 }
259
260 async fn run<F, T>(&self, future: F) -> Result<T>
261 where
262 F: Future<Output = Result<T>>,
263 {
264 #[cfg(test)]
265 {
266 if let Executor::Deterministic(executor) = &self.executor {
267 executor.simulate_random_delay().await;
268 }
269
270 self.runtime.as_ref().unwrap().block_on(future)
271 }
272
273 #[cfg(not(test))]
274 {
275 future.await
276 }
277 }
278
279 async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
280 // If the error is due to a failure to serialize concurrent transactions, then retry
281 // this transaction after a delay. With each subsequent retry, double the delay duration.
282 // Also vary the delay randomly in order to ensure different database connections retry
283 // at different times.
284 if is_serialization_error(error) {
285 let base_delay = 4_u64 << prev_attempt_count.min(16);
286 let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
287 log::info!(
288 "retrying transaction after serialization error. delay: {} ms.",
289 randomized_delay
290 );
291 self.executor
292 .sleep(Duration::from_millis(randomized_delay as u64))
293 .await;
294 true
295 } else {
296 false
297 }
298 }
299}
300
301fn is_serialization_error(error: &Error) -> bool {
302 const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
303 match error {
304 Error::Database(
305 DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
306 | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
307 ) if error
308 .as_database_error()
309 .and_then(|error| error.code())
310 .as_deref()
311 == Some(SERIALIZATION_FAILURE_CODE) =>
312 {
313 true
314 }
315 _ => false,
316 }
317}
318
319struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
320
321impl Deref for TransactionHandle {
322 type Target = DatabaseTransaction;
323
324 fn deref(&self) -> &Self::Target {
325 self.0.as_ref().as_ref().unwrap()
326 }
327}
328
329pub struct RoomGuard<T> {
330 data: T,
331 _guard: OwnedMutexGuard<()>,
332 _not_send: PhantomData<Rc<()>>,
333}
334
335impl<T> Deref for RoomGuard<T> {
336 type Target = T;
337
338 fn deref(&self) -> &T {
339 &self.data
340 }
341}
342
343impl<T> DerefMut for RoomGuard<T> {
344 fn deref_mut(&mut self) -> &mut T {
345 &mut self.data
346 }
347}
348
349impl<T> RoomGuard<T> {
350 pub fn into_inner(self) -> T {
351 self.data
352 }
353}
354
355#[derive(Clone, Debug, PartialEq, Eq)]
356pub enum Contact {
357 Accepted {
358 user_id: UserId,
359 should_notify: bool,
360 busy: bool,
361 },
362 Outgoing {
363 user_id: UserId,
364 },
365 Incoming {
366 user_id: UserId,
367 should_notify: bool,
368 },
369}
370
371impl Contact {
372 pub fn user_id(&self) -> UserId {
373 match self {
374 Contact::Accepted { user_id, .. } => *user_id,
375 Contact::Outgoing { user_id } => *user_id,
376 Contact::Incoming { user_id, .. } => *user_id,
377 }
378 }
379}
380
/// An email invite. It can be loaded directly from a query row (`FromQueryResult`) and
/// (de)serialized with serde.
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}

/// Details of a new signup.
///
/// Only `Deserialize` is derived, so values are built from incoming data rather than
/// from query rows.
// NOTE(review): presumably the body of a signup request; confirm with the callers.
#[derive(Clone, Debug, Deserialize)]
pub struct NewSignup {
    pub email_address: String,
    /// Per-platform flags.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
    // NOTE(review): what timestamp is stored when this is `None` is decided by the
    // query code, not here. Confirm in the `queries` modules.
    pub created_at: Option<DateTime>,
}

/// Aggregate signup counts, loadable from a single query row.
// NOTE(review): `count` looks like the overall total, with the other fields as
// per-platform breakdowns. Confirm against the query that produces this row.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
pub struct WaitlistSummary {
    pub count: i64,
    pub linux_count: i64,
    pub mac_count: i64,
    pub windows_count: i64,
    pub unknown_count: i64,
}

/// Parameters for creating a user identified by a GitHub login and GitHub user id.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}

/// Outcome of creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    // NOTE(review): presumably `Some` only when the user signed up via an invite; verify.
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
423
/// A channel, loadable from a query row.
#[derive(FromQueryResult, Debug, PartialEq)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    // NOTE(review): presumably `None` for a top-level channel; confirm.
    pub parent_id: Option<ChannelId>,
}

/// The channels returned for a user. Includes each channel's participants and the ids
/// of the channels where the user has admin privileges.
#[derive(Debug, PartialEq)]
pub struct ChannelsForUser {
    pub channels: Vec<Channel>,
    /// Participant user ids, keyed by channel.
    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
    pub channels_with_admin_privileges: HashSet<ChannelId>,
}

/// Result of joining a room: the room's protobuf state, plus an optional channel id
/// and a list of channel members.
// NOTE(review): presumably `channel_id` is `Some` when the room belongs to a channel;
// confirm in the `queries` modules.
#[derive(Clone)]
pub struct JoinRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}
444
/// Result of rejoining a room.
///
/// Holds the room's protobuf state, the projects that were rejoined or reshared, and
/// the optional channel id and channel members.
pub struct RejoinedRoom {
    pub room: proto::Room,
    pub rejoined_projects: Vec<RejoinedProject>,
    pub reshared_projects: Vec<ResharedProject>,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}

/// A project that was reshared as part of rejoining a room.
// NOTE(review): presumably a project hosted by the rejoining user, and
// `old_connection_id` is that user's connection before the rejoin; confirm.
pub struct ResharedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<proto::WorktreeMetadata>,
}

/// A project that was rejoined as part of rejoining a room, with its collaborators,
/// worktree state, and language servers.
pub struct RejoinedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<RejoinedWorktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}

/// Worktree state sent to a client that rejoins a project.
// NOTE(review): the updated/removed lists are presumably relative to the state the
// rejoining client last saw; verify against the query that builds this.
#[derive(Debug)]
pub struct RejoinedWorktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub updated_entries: Vec<proto::Entry>,
    /// Ids of removed entries.
    pub removed_entries: Vec<u64>,
    pub updated_repositories: Vec<proto::RepositoryEntry>,
    /// Ids of removed repository entries.
    pub removed_repositories: Vec<u64>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
483
/// Result of leaving a room.
///
/// Holds the room's protobuf state, the optional channel id and channel members, the
/// projects that were left (keyed by project id), the users whose calls were canceled,
/// and whether the room was deleted.
pub struct LeftRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
    pub deleted: bool,
}

/// Result of refreshing a room.
///
/// Holds the room's protobuf state, the optional channel id and channel members, the
/// participants reported as stale, and the users whose calls were canceled.
pub struct RefreshedRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub stale_participant_user_ids: Vec<UserId>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
500
/// A project's collaborators, worktrees (ordered by worktree id), and language servers.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}

/// A participant in a project, identified by connection, user, and replica id.
pub struct ProjectCollaborator {
    pub connection_id: ConnectionId,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    /// Whether this collaborator is the project's host.
    pub is_host: bool,
}
513
514impl ProjectCollaborator {
515 pub fn to_proto(&self) -> proto::Collaborator {
516 proto::Collaborator {
517 peer_id: Some(self.connection_id.into()),
518 replica_id: self.replica_id.0 as u32,
519 user_id: self.user_id.to_proto(),
520 }
521 }
522}
523
/// A project affected by a user leaving. Records the project's host and a set of
/// connection ids.
// NOTE(review): `connection_ids` are presumably the remaining connections to notify;
// confirm against the code that builds this.
#[derive(Debug)]
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    pub connection_ids: Vec<ConnectionId>,
}

/// Full state of a project worktree: entries, repository entries (keyed by `u64`),
/// diagnostics, settings files, and scan ids.
pub struct Worktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}

/// A settings file in a worktree: its path and its raw text content.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    pub path: String,
    pub content: String,
}