1#[cfg(test)]
2mod db_tests;
3#[cfg(test)]
4pub mod test_db;
5
6mod ids;
7mod queries;
8mod tables;
9
10use crate::{executor::Executor, Error, Result};
11use anyhow::anyhow;
12use collections::{BTreeMap, HashMap, HashSet};
13use dashmap::DashMap;
14use futures::StreamExt;
15use rand::{prelude::StdRng, Rng, SeedableRng};
16use rpc::{proto, ConnectionId};
17use sea_orm::{
18 entity::prelude::*, ActiveValue, Condition, ConnectionTrait, DatabaseConnection,
19 DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType,
20 QueryOrder, QuerySelect, Statement, TransactionTrait,
21};
22use sea_query::{Alias, Expr, OnConflict, Query};
23use serde::{Deserialize, Serialize};
24use sqlx::{
25 migrate::{Migrate, Migration, MigrationSource},
26 Connection,
27};
28use std::{
29 fmt::Write as _,
30 future::Future,
31 marker::PhantomData,
32 ops::{Deref, DerefMut},
33 path::Path,
34 rc::Rc,
35 sync::Arc,
36 time::Duration,
37};
38use tables::*;
39use tokio::sync::{Mutex, OwnedMutexGuard};
40
41pub use ids::*;
42pub use sea_orm::ConnectOptions;
43pub use tables::user::Model as User;
44
/// Handle to the collab database: a sea-orm connection pool plus the in-process state used
/// to run transactions (per-room locks, retry jitter, and the executor used for delays).
45pub struct Database {
 /// Connection options; kept so `migrate` can open a separate sqlx connection to the same URL.
46 options: ConnectOptions,
 /// Pool that every transaction is started on.
47 pool: DatabaseConnection,
 /// One async mutex per room, created on demand. Room-scoped transactions hold the room's
 /// lock at least while committing, and the returned `RoomGuard` keeps holding it.
48 rooms: DashMap<RoomId, Arc<Mutex<()>>>,
 /// RNG used to jitter serialization-failure retry delays (seeded with 0 in `new`).
49 rng: Mutex<StdRng>,
 /// Executor used for retry sleeps and, in tests, for simulated random delays.
50 executor: Executor,
 /// Test-only runtime that `run` blocks on; `new` leaves it `None`.
51 #[cfg(test)]
52 runtime: Option<tokio::runtime::Runtime>,
53}
54
55impl Database {
56 pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
57 Ok(Self {
58 options: options.clone(),
59 pool: sea_orm::Database::connect(options).await?,
60 rooms: DashMap::with_capacity(16384),
61 rng: Mutex::new(StdRng::seed_from_u64(0)),
62 executor,
63 #[cfg(test)]
64 runtime: None,
65 })
66 }
67
68 #[cfg(test)]
69 pub fn reset(&self) {
70 self.rooms.clear();
71 }
72
73 pub async fn migrate(
74 &self,
75 migrations_path: &Path,
76 ignore_checksum_mismatch: bool,
77 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
78 let migrations = MigrationSource::resolve(migrations_path)
79 .await
80 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
81
82 let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
83
84 connection.ensure_migrations_table().await?;
85 let applied_migrations: HashMap<_, _> = connection
86 .list_applied_migrations()
87 .await?
88 .into_iter()
89 .map(|m| (m.version, m))
90 .collect();
91
92 let mut new_migrations = Vec::new();
93 for migration in migrations {
94 match applied_migrations.get(&migration.version) {
95 Some(applied_migration) => {
96 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
97 {
98 Err(anyhow!(
99 "checksum mismatch for applied migration {}",
100 migration.description
101 ))?;
102 }
103 }
104 None => {
105 let elapsed = connection.apply(&migration).await?;
106 new_migrations.push((migration, elapsed));
107 }
108 }
109 }
110
111 Ok(new_migrations)
112 }
113
114 async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
115 where
116 F: Send + Fn(TransactionHandle) -> Fut,
117 Fut: Send + Future<Output = Result<T>>,
118 {
119 let body = async {
120 let mut i = 0;
121 loop {
122 let (tx, result) = self.with_transaction(&f).await?;
123 match result {
124 Ok(result) => match tx.commit().await.map_err(Into::into) {
125 Ok(()) => return Ok(result),
126 Err(error) => {
127 if !self.retry_on_serialization_error(&error, i).await {
128 return Err(error);
129 }
130 }
131 },
132 Err(error) => {
133 tx.rollback().await?;
134 if !self.retry_on_serialization_error(&error, i).await {
135 return Err(error);
136 }
137 }
138 }
139 i += 1;
140 }
141 };
142
143 self.run(body).await
144 }
145
146 async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
147 where
148 F: Send + Fn(TransactionHandle) -> Fut,
149 Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
150 {
151 let body = async {
152 let mut i = 0;
153 loop {
154 let (tx, result) = self.with_transaction(&f).await?;
155 match result {
156 Ok(Some((room_id, data))) => {
157 let lock = self.rooms.entry(room_id).or_default().clone();
158 let _guard = lock.lock_owned().await;
159 match tx.commit().await.map_err(Into::into) {
160 Ok(()) => {
161 return Ok(Some(RoomGuard {
162 data,
163 _guard,
164 _not_send: PhantomData,
165 }));
166 }
167 Err(error) => {
168 if !self.retry_on_serialization_error(&error, i).await {
169 return Err(error);
170 }
171 }
172 }
173 }
174 Ok(None) => match tx.commit().await.map_err(Into::into) {
175 Ok(()) => return Ok(None),
176 Err(error) => {
177 if !self.retry_on_serialization_error(&error, i).await {
178 return Err(error);
179 }
180 }
181 },
182 Err(error) => {
183 tx.rollback().await?;
184 if !self.retry_on_serialization_error(&error, i).await {
185 return Err(error);
186 }
187 }
188 }
189 i += 1;
190 }
191 };
192
193 self.run(body).await
194 }
195
196 async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
197 where
198 F: Send + Fn(TransactionHandle) -> Fut,
199 Fut: Send + Future<Output = Result<T>>,
200 {
201 let body = async {
202 let mut i = 0;
203 loop {
204 let lock = self.rooms.entry(room_id).or_default().clone();
205 let _guard = lock.lock_owned().await;
206 let (tx, result) = self.with_transaction(&f).await?;
207 match result {
208 Ok(data) => match tx.commit().await.map_err(Into::into) {
209 Ok(()) => {
210 return Ok(RoomGuard {
211 data,
212 _guard,
213 _not_send: PhantomData,
214 });
215 }
216 Err(error) => {
217 if !self.retry_on_serialization_error(&error, i).await {
218 return Err(error);
219 }
220 }
221 },
222 Err(error) => {
223 tx.rollback().await?;
224 if !self.retry_on_serialization_error(&error, i).await {
225 return Err(error);
226 }
227 }
228 }
229 i += 1;
230 }
231 };
232
233 self.run(body).await
234 }
235
236 async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
237 where
238 F: Send + Fn(TransactionHandle) -> Fut,
239 Fut: Send + Future<Output = Result<T>>,
240 {
241 let tx = self
242 .pool
243 .begin_with_config(Some(IsolationLevel::Serializable), None)
244 .await?;
245
246 let mut tx = Arc::new(Some(tx));
247 let result = f(TransactionHandle(tx.clone())).await;
248 let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
249 return Err(anyhow!("couldn't complete transaction because it's still in use"))?;
250 };
251
252 Ok((tx, result))
253 }
254
255 async fn run<F, T>(&self, future: F) -> Result<T>
256 where
257 F: Future<Output = Result<T>>,
258 {
259 #[cfg(test)]
260 {
261 if let Executor::Deterministic(executor) = &self.executor {
262 executor.simulate_random_delay().await;
263 }
264
265 self.runtime.as_ref().unwrap().block_on(future)
266 }
267
268 #[cfg(not(test))]
269 {
270 future.await
271 }
272 }
273
274 async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
275 // If the error is due to a failure to serialize concurrent transactions, then retry
276 // this transaction after a delay. With each subsequent retry, double the delay duration.
277 // Also vary the delay randomly in order to ensure different database connections retry
278 // at different times.
279 if is_serialization_error(error) {
280 let base_delay = 4_u64 << prev_attempt_count.min(16);
281 let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
282 log::info!(
283 "retrying transaction after serialization error. delay: {} ms.",
284 randomized_delay
285 );
286 self.executor
287 .sleep(Duration::from_millis(randomized_delay as u64))
288 .await;
289 true
290 } else {
291 false
292 }
293 }
294}
295
296fn is_serialization_error(error: &Error) -> bool {
297 const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
298 match error {
299 Error::Database(
300 DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
301 | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
302 ) if error
303 .as_database_error()
304 .and_then(|error| error.code())
305 .as_deref()
306 == Some(SERIALIZATION_FAILURE_CODE) =>
307 {
308 true
309 }
310 _ => false,
311 }
312}
313
/// Shared handle to the in-flight transaction, passed to transaction closures.
///
/// `Database::with_transaction` keeps its own clone of the `Arc`. Once the closure's future
/// completes, it takes the transaction back out, which only succeeds if no handle clone is
/// still alive.
314struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
315
316impl Deref for TransactionHandle {
317 type Target = DatabaseTransaction;
318
319 fn deref(&self) -> &Self::Target {
320 self.0.as_ref().as_ref().unwrap()
321 }
322}
323
/// A value produced by a room-scoped transaction, bundled with that room's lock guard.
///
/// The room's mutex stays locked until this guard (or the guard moved out with it) is dropped.
/// Use `Deref`/`DerefMut` to access the value, or `into_inner` to take it.
324pub struct RoomGuard<T> {
325 data: T,
 /// Owned guard on the room's mutex; released when the `RoomGuard` is dropped.
326 _guard: OwnedMutexGuard<()>,
 /// Makes `RoomGuard` neither `Send` nor `Sync`, since `Rc` is neither.
327 _not_send: PhantomData<Rc<()>>,
328}
329
330impl<T> Deref for RoomGuard<T> {
331 type Target = T;
332
333 fn deref(&self) -> &T {
334 &self.data
335 }
336}
337
338impl<T> DerefMut for RoomGuard<T> {
339 fn deref_mut(&mut self) -> &mut T {
340 &mut self.data
341 }
342}
343
344impl<T> RoomGuard<T> {
345 pub fn into_inner(self) -> T {
346 self.data
347 }
348}
349
/// A contact entry for some user. Every variant carries that user's id (see `Contact::user_id`).
///
/// NOTE(review): the variant meanings (accepted contact, request sent, request received) and
/// the `should_notify` / `busy` flags are inferred from their names. Confirm against the
/// contacts queries.
350#[derive(Clone, Debug, PartialEq, Eq)]
351pub enum Contact {
352 Accepted {
353 user_id: UserId,
354 should_notify: bool,
355 busy: bool,
356 },
357 Outgoing {
358 user_id: UserId,
359 },
360 Incoming {
361 user_id: UserId,
362 should_notify: bool,
363 },
364}
365
366impl Contact {
367 pub fn user_id(&self) -> UserId {
368 match self {
369 Contact::Accepted { user_id, .. } => *user_id,
370 Contact::Outgoing { user_id } => *user_id,
371 Contact::Incoming { user_id, .. } => *user_id,
372 }
373 }
374}
375
/// An invite's email address and confirmation code.
///
/// Can be read directly from a query row (`FromQueryResult`) and (de)serialized with serde.
376#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
377pub struct Invite {
378 pub email_address: String,
379 pub email_confirmation_code: String,
380}
381
/// Details of a new signup, deserialized with serde.
///
/// NOTE(review): field meanings are inferred from names. Confirm with the signup handler.
382#[derive(Clone, Debug, Deserialize)]
383pub struct NewSignup {
384 pub email_address: String,
 /// Platform flags (`platform_mac` / `platform_windows` / `platform_linux`); presumably the
 /// platforms the signup selected. Confirm.
385 pub platform_mac: bool,
386 pub platform_windows: bool,
387 pub platform_linux: bool,
388 pub editor_features: Vec<String>,
389 pub programming_languages: Vec<String>,
390 pub device_id: Option<String>,
391 pub added_to_mailing_list: bool,
 /// Optional creation timestamp; serde yields `None` when the field is absent.
392 pub created_at: Option<DateTime>,
393}
394
/// Aggregate waitlist counts. Readable from a query row and (de)serializable with serde.
///
/// NOTE(review): presumably `count` is the total and the rest are per-platform breakdowns,
/// with `unknown_count` covering signups with no platform. Confirm against the query.
395#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
396pub struct WaitlistSummary {
397 pub count: i64,
398 pub linux_count: i64,
399 pub mac_count: i64,
400 pub windows_count: i64,
401 pub unknown_count: i64,
402}
403
/// Parameters for creating a new user: a GitHub login and user id, plus an invite count.
/// (De)serializable with serde.
404#[derive(Debug, Serialize, Deserialize)]
405pub struct NewUserParams {
406 pub github_login: String,
407 pub github_user_id: i32,
408 pub invite_count: i32,
409}
410
/// Outcome of creating a new user.
///
/// NOTE(review): `inviting_user_id` and `signup_device_id` are presumably populated from the
/// invite or signup the user was created from, when there was one. Confirm.
411#[derive(Debug)]
412pub struct NewUserResult {
413 pub user_id: UserId,
414 pub metrics_id: String,
415 pub inviting_user_id: Option<UserId>,
416 pub signup_device_id: Option<String>,
417}
418
/// A channel row (readable via `FromQueryResult`): id, name, and optional parent channel id.
419#[derive(FromQueryResult, Debug, PartialEq)]
420pub struct Channel {
421 pub id: ChannelId,
422 pub name: String,
423 pub parent_id: Option<ChannelId>,
424}
425
/// Channels loaded for a user, together with per-channel participant ids and the ids of
/// channels in which the user has admin privileges.
///
/// NOTE(review): which channels are included (e.g. only memberships) is decided by the
/// query that builds this. Confirm there.
426#[derive(Debug, PartialEq)]
427pub struct ChannelsForUser {
428 pub channels: Vec<Channel>,
429 pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
430 pub channels_with_admin_privileges: HashSet<ChannelId>,
431}
432
/// Result of joining a room: the room's proto state, plus its channel and that channel's
/// members when the room belongs to a channel (presumably; confirm against the join query).
433#[derive(Clone)]
434pub struct JoinRoom {
435 pub room: proto::Room,
436 pub channel_id: Option<ChannelId>,
437 pub channel_members: Vec<UserId>,
438}
439
/// Result of rejoining a room: the room's proto state, the projects that were rejoined or
/// reshared, and the room's channel and channel members (if any).
///
/// NOTE(review): presumably produced when a client reconnects. Confirm against the rejoin query.
440pub struct RejoinedRoom {
441 pub room: proto::Room,
442 pub rejoined_projects: Vec<RejoinedProject>,
443 pub reshared_projects: Vec<ResharedProject>,
444 pub channel_id: Option<ChannelId>,
445 pub channel_members: Vec<UserId>,
446}
447
/// A project re-shared as part of a `RejoinedRoom`, with its collaborators and worktree metadata.
///
/// NOTE(review): `old_connection_id` is presumably the connection the project was shared
/// from before the reconnect. Confirm.
448pub struct ResharedProject {
449 pub id: ProjectId,
450 pub old_connection_id: ConnectionId,
451 pub collaborators: Vec<ProjectCollaborator>,
452 pub worktrees: Vec<proto::WorktreeMetadata>,
453}
454
/// A project rejoined as part of a `RejoinedRoom`, with its collaborators, rejoined
/// worktrees, and language servers.
///
/// NOTE(review): `old_connection_id` is presumably the connection used before the reconnect.
/// Confirm.
455pub struct RejoinedProject {
456 pub id: ProjectId,
457 pub old_connection_id: ConnectionId,
458 pub collaborators: Vec<ProjectCollaborator>,
459 pub worktrees: Vec<RejoinedWorktree>,
460 pub language_servers: Vec<proto::LanguageServer>,
461}
462
/// Worktree state sent to a rejoining collaborator.
///
/// NOTE(review): the `updated_*` / `removed_*` lists are presumably a delta relative to what
/// the client already had, rather than a full snapshot (compare `Worktree`). Confirm against
/// the rejoin query.
463#[derive(Debug)]
464pub struct RejoinedWorktree {
465 pub id: u64,
466 pub abs_path: String,
467 pub root_name: String,
468 pub visible: bool,
469 pub updated_entries: Vec<proto::Entry>,
470 pub removed_entries: Vec<u64>,
471 pub updated_repositories: Vec<proto::RepositoryEntry>,
472 pub removed_repositories: Vec<u64>,
473 pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
474 pub settings_files: Vec<WorktreeSettingsFile>,
475 pub scan_id: u64,
476 pub completed_scan_id: u64,
477}
478
/// Result of a participant leaving a room.
///
/// Contains the room's proto state, its channel and channel members (if any), the projects
/// affected (keyed by project id), and the user ids whose pending calls were canceled.
///
/// NOTE(review): `deleted` presumably reports whether the room itself was removed. Confirm.
479pub struct LeftRoom {
480 pub room: proto::Room,
481 pub channel_id: Option<ChannelId>,
482 pub channel_members: Vec<UserId>,
483 pub left_projects: HashMap<ProjectId, LeftProject>,
484 pub canceled_calls_to_user_ids: Vec<UserId>,
485 pub deleted: bool,
486}
487
/// Result of refreshing a room.
///
/// Contains the room's proto state, its channel and channel members (if any), the user ids
/// of participants considered stale, and the user ids whose pending calls were canceled.
///
/// NOTE(review): what "stale" means is decided by the refresh query. Confirm there.
488pub struct RefreshedRoom {
489 pub room: proto::Room,
490 pub channel_id: Option<ChannelId>,
491 pub channel_members: Vec<UserId>,
492 pub stale_participant_user_ids: Vec<UserId>,
493 pub canceled_calls_to_user_ids: Vec<UserId>,
494}
495
/// A loaded project: its collaborators, worktrees, and language servers.
///
/// NOTE(review): `worktrees` is presumably keyed by worktree id. Confirm.
496pub struct Project {
497 pub collaborators: Vec<ProjectCollaborator>,
498 pub worktrees: BTreeMap<u64, Worktree>,
499 pub language_servers: Vec<proto::LanguageServer>,
500}
501
/// A collaborator in a project.
///
/// `to_proto` sends `connection_id` as the proto `peer_id`. `is_host` is not included in the
/// proto message.
502pub struct ProjectCollaborator {
503 pub connection_id: ConnectionId,
504 pub user_id: UserId,
505 pub replica_id: ReplicaId,
506 pub is_host: bool,
507}
508
509impl ProjectCollaborator {
510 pub fn to_proto(&self) -> proto::Collaborator {
511 proto::Collaborator {
512 peer_id: Some(self.connection_id.into()),
513 replica_id: self.replica_id.0 as u32,
514 user_id: self.user_id.to_proto(),
515 }
516 }
517}
518
/// A project affected by a participant leaving a room (see `LeftRoom::left_projects`).
///
/// Holds the project's host user and host connection, plus a list of connection ids.
///
/// NOTE(review): `connection_ids` is presumably the remaining connections to notify. Confirm.
519#[derive(Debug)]
520pub struct LeftProject {
521 pub id: ProjectId,
522 pub host_user_id: UserId,
523 pub host_connection_id: ConnectionId,
524 pub connection_ids: Vec<ConnectionId>,
525}
526
/// Full worktree state as stored for a project (see `Project::worktrees`).
///
/// NOTE(review): `repository_entries` is presumably keyed by repository entry id, and
/// `completed_scan_id` presumably trails `scan_id` while a scan is in progress. Confirm.
527pub struct Worktree {
528 pub id: u64,
529 pub abs_path: String,
530 pub root_name: String,
531 pub visible: bool,
532 pub entries: Vec<proto::Entry>,
533 pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
534 pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
535 pub settings_files: Vec<WorktreeSettingsFile>,
536 pub scan_id: u64,
537 pub completed_scan_id: u64,
538}
539
/// A settings file belonging to a worktree: its path and its contents as a string.
540#[derive(Debug)]
541pub struct WorktreeSettingsFile {
542 pub path: String,
543 pub content: String,
544}