1#[cfg(test)]
2pub mod tests;
3
4#[cfg(test)]
5pub use tests::TestDb;
6
7mod ids;
8mod queries;
9mod tables;
10
11use crate::{executor::Executor, Error, Result};
12use anyhow::anyhow;
13use collections::{BTreeMap, HashMap, HashSet};
14use dashmap::DashMap;
15use futures::StreamExt;
16use rand::{prelude::StdRng, Rng, SeedableRng};
17use rpc::{proto, ConnectionId};
18use sea_orm::{
19 entity::prelude::*, ActiveValue, Condition, ConnectionTrait, DatabaseConnection,
20 DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType,
21 QueryOrder, QuerySelect, Statement, TransactionTrait,
22};
23use sea_query::{Alias, Expr, OnConflict, Query};
24use serde::{Deserialize, Serialize};
25use sqlx::{
26 migrate::{Migrate, Migration, MigrationSource},
27 Connection,
28};
29use std::{
30 fmt::Write as _,
31 future::Future,
32 marker::PhantomData,
33 ops::{Deref, DerefMut},
34 path::Path,
35 rc::Rc,
36 sync::Arc,
37 time::Duration,
38};
39use tables::*;
40use tokio::sync::{Mutex, OwnedMutexGuard};
41
42pub use ids::*;
43pub use sea_orm::ConnectOptions;
44pub use tables::user::Model as User;
45
46pub struct Database {
47 options: ConnectOptions,
48 pool: DatabaseConnection,
49 rooms: DashMap<RoomId, Arc<Mutex<()>>>,
50 rng: Mutex<StdRng>,
51 executor: Executor,
52 #[cfg(test)]
53 runtime: Option<tokio::runtime::Runtime>,
54}
55
56// The `Database` type has so many methods that its impl blocks are split into
57// separate files in the `queries` folder.
58impl Database {
59 pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
60 Ok(Self {
61 options: options.clone(),
62 pool: sea_orm::Database::connect(options).await?,
63 rooms: DashMap::with_capacity(16384),
64 rng: Mutex::new(StdRng::seed_from_u64(0)),
65 executor,
66 #[cfg(test)]
67 runtime: None,
68 })
69 }
70
71 #[cfg(test)]
72 pub fn reset(&self) {
73 self.rooms.clear();
74 }
75
76 pub async fn migrate(
77 &self,
78 migrations_path: &Path,
79 ignore_checksum_mismatch: bool,
80 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
81 let migrations = MigrationSource::resolve(migrations_path)
82 .await
83 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
84
85 let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
86
87 connection.ensure_migrations_table().await?;
88 let applied_migrations: HashMap<_, _> = connection
89 .list_applied_migrations()
90 .await?
91 .into_iter()
92 .map(|m| (m.version, m))
93 .collect();
94
95 let mut new_migrations = Vec::new();
96 for migration in migrations {
97 match applied_migrations.get(&migration.version) {
98 Some(applied_migration) => {
99 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
100 {
101 Err(anyhow!(
102 "checksum mismatch for applied migration {}",
103 migration.description
104 ))?;
105 }
106 }
107 None => {
108 let elapsed = connection.apply(&migration).await?;
109 new_migrations.push((migration, elapsed));
110 }
111 }
112 }
113
114 Ok(new_migrations)
115 }
116
117 async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
118 where
119 F: Send + Fn(TransactionHandle) -> Fut,
120 Fut: Send + Future<Output = Result<T>>,
121 {
122 let body = async {
123 let mut i = 0;
124 loop {
125 let (tx, result) = self.with_transaction(&f).await?;
126 match result {
127 Ok(result) => match tx.commit().await.map_err(Into::into) {
128 Ok(()) => return Ok(result),
129 Err(error) => {
130 if !self.retry_on_serialization_error(&error, i).await {
131 return Err(error);
132 }
133 }
134 },
135 Err(error) => {
136 tx.rollback().await?;
137 if !self.retry_on_serialization_error(&error, i).await {
138 return Err(error);
139 }
140 }
141 }
142 i += 1;
143 }
144 };
145
146 self.run(body).await
147 }
148
149 async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
150 where
151 F: Send + Fn(TransactionHandle) -> Fut,
152 Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
153 {
154 let body = async {
155 let mut i = 0;
156 loop {
157 let (tx, result) = self.with_transaction(&f).await?;
158 match result {
159 Ok(Some((room_id, data))) => {
160 let lock = self.rooms.entry(room_id).or_default().clone();
161 let _guard = lock.lock_owned().await;
162 match tx.commit().await.map_err(Into::into) {
163 Ok(()) => {
164 return Ok(Some(RoomGuard {
165 data,
166 _guard,
167 _not_send: PhantomData,
168 }));
169 }
170 Err(error) => {
171 if !self.retry_on_serialization_error(&error, i).await {
172 return Err(error);
173 }
174 }
175 }
176 }
177 Ok(None) => match tx.commit().await.map_err(Into::into) {
178 Ok(()) => return Ok(None),
179 Err(error) => {
180 if !self.retry_on_serialization_error(&error, i).await {
181 return Err(error);
182 }
183 }
184 },
185 Err(error) => {
186 tx.rollback().await?;
187 if !self.retry_on_serialization_error(&error, i).await {
188 return Err(error);
189 }
190 }
191 }
192 i += 1;
193 }
194 };
195
196 self.run(body).await
197 }
198
199 async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
200 where
201 F: Send + Fn(TransactionHandle) -> Fut,
202 Fut: Send + Future<Output = Result<T>>,
203 {
204 let body = async {
205 let mut i = 0;
206 loop {
207 let lock = self.rooms.entry(room_id).or_default().clone();
208 let _guard = lock.lock_owned().await;
209 let (tx, result) = self.with_transaction(&f).await?;
210 match result {
211 Ok(data) => match tx.commit().await.map_err(Into::into) {
212 Ok(()) => {
213 return Ok(RoomGuard {
214 data,
215 _guard,
216 _not_send: PhantomData,
217 });
218 }
219 Err(error) => {
220 if !self.retry_on_serialization_error(&error, i).await {
221 return Err(error);
222 }
223 }
224 },
225 Err(error) => {
226 tx.rollback().await?;
227 if !self.retry_on_serialization_error(&error, i).await {
228 return Err(error);
229 }
230 }
231 }
232 i += 1;
233 }
234 };
235
236 self.run(body).await
237 }
238
239 async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
240 where
241 F: Send + Fn(TransactionHandle) -> Fut,
242 Fut: Send + Future<Output = Result<T>>,
243 {
244 let tx = self
245 .pool
246 .begin_with_config(Some(IsolationLevel::Serializable), None)
247 .await?;
248
249 let mut tx = Arc::new(Some(tx));
250 let result = f(TransactionHandle(tx.clone())).await;
251 let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
252 return Err(anyhow!("couldn't complete transaction because it's still in use"))?;
253 };
254
255 Ok((tx, result))
256 }
257
258 async fn run<F, T>(&self, future: F) -> Result<T>
259 where
260 F: Future<Output = Result<T>>,
261 {
262 #[cfg(test)]
263 {
264 if let Executor::Deterministic(executor) = &self.executor {
265 executor.simulate_random_delay().await;
266 }
267
268 self.runtime.as_ref().unwrap().block_on(future)
269 }
270
271 #[cfg(not(test))]
272 {
273 future.await
274 }
275 }
276
277 async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
278 // If the error is due to a failure to serialize concurrent transactions, then retry
279 // this transaction after a delay. With each subsequent retry, double the delay duration.
280 // Also vary the delay randomly in order to ensure different database connections retry
281 // at different times.
282 if is_serialization_error(error) {
283 let base_delay = 4_u64 << prev_attempt_count.min(16);
284 let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
285 log::info!(
286 "retrying transaction after serialization error. delay: {} ms.",
287 randomized_delay
288 );
289 self.executor
290 .sleep(Duration::from_millis(randomized_delay as u64))
291 .await;
292 true
293 } else {
294 false
295 }
296 }
297}
298
299fn is_serialization_error(error: &Error) -> bool {
300 const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
301 match error {
302 Error::Database(
303 DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
304 | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
305 ) if error
306 .as_database_error()
307 .and_then(|error| error.code())
308 .as_deref()
309 == Some(SERIALIZATION_FAILURE_CODE) =>
310 {
311 true
312 }
313 _ => false,
314 }
315}
316
317struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
318
319impl Deref for TransactionHandle {
320 type Target = DatabaseTransaction;
321
322 fn deref(&self) -> &Self::Target {
323 self.0.as_ref().as_ref().unwrap()
324 }
325}
326
327pub struct RoomGuard<T> {
328 data: T,
329 _guard: OwnedMutexGuard<()>,
330 _not_send: PhantomData<Rc<()>>,
331}
332
333impl<T> Deref for RoomGuard<T> {
334 type Target = T;
335
336 fn deref(&self) -> &T {
337 &self.data
338 }
339}
340
341impl<T> DerefMut for RoomGuard<T> {
342 fn deref_mut(&mut self) -> &mut T {
343 &mut self.data
344 }
345}
346
347impl<T> RoomGuard<T> {
348 pub fn into_inner(self) -> T {
349 self.data
350 }
351}
352
353#[derive(Clone, Debug, PartialEq, Eq)]
354pub enum Contact {
355 Accepted {
356 user_id: UserId,
357 should_notify: bool,
358 busy: bool,
359 },
360 Outgoing {
361 user_id: UserId,
362 },
363 Incoming {
364 user_id: UserId,
365 should_notify: bool,
366 },
367}
368
369impl Contact {
370 pub fn user_id(&self) -> UserId {
371 match self {
372 Contact::Accepted { user_id, .. } => *user_id,
373 Contact::Outgoing { user_id } => *user_id,
374 Contact::Incoming { user_id, .. } => *user_id,
375 }
376 }
377}
378
379#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
380pub struct Invite {
381 pub email_address: String,
382 pub email_confirmation_code: String,
383}
384
385#[derive(Clone, Debug, Deserialize)]
386pub struct NewSignup {
387 pub email_address: String,
388 pub platform_mac: bool,
389 pub platform_windows: bool,
390 pub platform_linux: bool,
391 pub editor_features: Vec<String>,
392 pub programming_languages: Vec<String>,
393 pub device_id: Option<String>,
394 pub added_to_mailing_list: bool,
395 pub created_at: Option<DateTime>,
396}
397
398#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
399pub struct WaitlistSummary {
400 pub count: i64,
401 pub linux_count: i64,
402 pub mac_count: i64,
403 pub windows_count: i64,
404 pub unknown_count: i64,
405}
406
407#[derive(Debug, Serialize, Deserialize)]
408pub struct NewUserParams {
409 pub github_login: String,
410 pub github_user_id: i32,
411 pub invite_count: i32,
412}
413
414#[derive(Debug)]
415pub struct NewUserResult {
416 pub user_id: UserId,
417 pub metrics_id: String,
418 pub inviting_user_id: Option<UserId>,
419 pub signup_device_id: Option<String>,
420}
421
422#[derive(FromQueryResult, Debug, PartialEq)]
423pub struct Channel {
424 pub id: ChannelId,
425 pub name: String,
426 pub parent_id: Option<ChannelId>,
427}
428
429#[derive(Debug, PartialEq)]
430pub struct ChannelsForUser {
431 pub channels: Vec<Channel>,
432 pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
433 pub channels_with_admin_privileges: HashSet<ChannelId>,
434}
435
436#[derive(Clone)]
437pub struct JoinRoom {
438 pub room: proto::Room,
439 pub channel_id: Option<ChannelId>,
440 pub channel_members: Vec<UserId>,
441}
442
443pub struct RejoinedRoom {
444 pub room: proto::Room,
445 pub rejoined_projects: Vec<RejoinedProject>,
446 pub reshared_projects: Vec<ResharedProject>,
447 pub channel_id: Option<ChannelId>,
448 pub channel_members: Vec<UserId>,
449}
450
451pub struct ResharedProject {
452 pub id: ProjectId,
453 pub old_connection_id: ConnectionId,
454 pub collaborators: Vec<ProjectCollaborator>,
455 pub worktrees: Vec<proto::WorktreeMetadata>,
456}
457
458pub struct RejoinedProject {
459 pub id: ProjectId,
460 pub old_connection_id: ConnectionId,
461 pub collaborators: Vec<ProjectCollaborator>,
462 pub worktrees: Vec<RejoinedWorktree>,
463 pub language_servers: Vec<proto::LanguageServer>,
464}
465
466#[derive(Debug)]
467pub struct RejoinedWorktree {
468 pub id: u64,
469 pub abs_path: String,
470 pub root_name: String,
471 pub visible: bool,
472 pub updated_entries: Vec<proto::Entry>,
473 pub removed_entries: Vec<u64>,
474 pub updated_repositories: Vec<proto::RepositoryEntry>,
475 pub removed_repositories: Vec<u64>,
476 pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
477 pub settings_files: Vec<WorktreeSettingsFile>,
478 pub scan_id: u64,
479 pub completed_scan_id: u64,
480}
481
482pub struct LeftRoom {
483 pub room: proto::Room,
484 pub channel_id: Option<ChannelId>,
485 pub channel_members: Vec<UserId>,
486 pub left_projects: HashMap<ProjectId, LeftProject>,
487 pub canceled_calls_to_user_ids: Vec<UserId>,
488 pub deleted: bool,
489}
490
491pub struct RefreshedRoom {
492 pub room: proto::Room,
493 pub channel_id: Option<ChannelId>,
494 pub channel_members: Vec<UserId>,
495 pub stale_participant_user_ids: Vec<UserId>,
496 pub canceled_calls_to_user_ids: Vec<UserId>,
497}
498
499pub struct Project {
500 pub collaborators: Vec<ProjectCollaborator>,
501 pub worktrees: BTreeMap<u64, Worktree>,
502 pub language_servers: Vec<proto::LanguageServer>,
503}
504
505pub struct ProjectCollaborator {
506 pub connection_id: ConnectionId,
507 pub user_id: UserId,
508 pub replica_id: ReplicaId,
509 pub is_host: bool,
510}
511
512impl ProjectCollaborator {
513 pub fn to_proto(&self) -> proto::Collaborator {
514 proto::Collaborator {
515 peer_id: Some(self.connection_id.into()),
516 replica_id: self.replica_id.0 as u32,
517 user_id: self.user_id.to_proto(),
518 }
519 }
520}
521
522#[derive(Debug)]
523pub struct LeftProject {
524 pub id: ProjectId,
525 pub host_user_id: UserId,
526 pub host_connection_id: ConnectionId,
527 pub connection_ids: Vec<ConnectionId>,
528}
529
530pub struct Worktree {
531 pub id: u64,
532 pub abs_path: String,
533 pub root_name: String,
534 pub visible: bool,
535 pub entries: Vec<proto::Entry>,
536 pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
537 pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
538 pub settings_files: Vec<WorktreeSettingsFile>,
539 pub scan_id: u64,
540 pub completed_scan_id: u64,
541}
542
/// A settings file: its path and its full content, both stored as strings.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    /// Path of the settings file.
    pub path: String,
    /// Content of the settings file.
    pub content: String,
}