1#[cfg(test)]
2pub mod tests;
3
4#[cfg(test)]
5pub use tests::TestDb;
6
7mod ids;
8mod queries;
9mod tables;
10
11use crate::{executor::Executor, Error, Result};
12use anyhow::anyhow;
13use collections::{BTreeMap, HashMap, HashSet};
14use dashmap::DashMap;
15use futures::StreamExt;
16use rand::{prelude::StdRng, Rng, SeedableRng};
17use rpc::{proto::{self}, ConnectionId};
18use sea_orm::{
19 entity::prelude::*, ActiveValue, Condition, ConnectionTrait, DatabaseConnection,
20 DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType,
21 QueryOrder, QuerySelect, Statement, TransactionTrait,
22};
23use sea_query::{Alias, Expr, OnConflict, Query};
24use serde::{Deserialize, Serialize};
25use sqlx::{
26 migrate::{Migrate, Migration, MigrationSource},
27 Connection,
28};
29use std::{
30 fmt::Write as _,
31 future::Future,
32 marker::PhantomData,
33 ops::{Deref, DerefMut},
34 path::Path,
35 rc::Rc,
36 sync::Arc,
37 time::Duration,
38};
39use tables::*;
40use tokio::sync::{Mutex, OwnedMutexGuard};
41
42pub use ids::*;
43pub use sea_orm::ConnectOptions;
44pub use tables::user::Model as User;
45
46use self::queries::channels::ChannelGraph;
47
/// Handle to the database.
///
/// Besides the connection pool it keeps one async mutex per room (used by the
/// room transaction helpers), an RNG used to jitter transaction retries, and
/// the executor used to sleep between retries.
pub struct Database {
    /// Options the pool was created with; `migrate` reuses their URL to open a
    /// dedicated connection.
    options: ConnectOptions,
    /// Connection pool every transaction is started from.
    pool: DatabaseConnection,
    /// Per-room locks, created on first use via `entry(..).or_default()`.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// RNG used to randomize the delay before retrying a transaction.
    rng: Mutex<StdRng>,
    /// Used to sleep between retries and, in tests, to simulate random delays.
    executor: Executor,
    /// Test-only runtime on which `run` blocks to drive database futures.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
57
58// The `Database` type has so many methods that its impl blocks are split into
59// separate files in the `queries` folder.
60impl Database {
61 pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
62 Ok(Self {
63 options: options.clone(),
64 pool: sea_orm::Database::connect(options).await?,
65 rooms: DashMap::with_capacity(16384),
66 rng: Mutex::new(StdRng::seed_from_u64(0)),
67 executor,
68 #[cfg(test)]
69 runtime: None,
70 })
71 }
72
73 #[cfg(test)]
74 pub fn reset(&self) {
75 self.rooms.clear();
76 }
77
78 pub async fn migrate(
79 &self,
80 migrations_path: &Path,
81 ignore_checksum_mismatch: bool,
82 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
83 let migrations = MigrationSource::resolve(migrations_path)
84 .await
85 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
86
87 let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
88
89 connection.ensure_migrations_table().await?;
90 let applied_migrations: HashMap<_, _> = connection
91 .list_applied_migrations()
92 .await?
93 .into_iter()
94 .map(|m| (m.version, m))
95 .collect();
96
97 let mut new_migrations = Vec::new();
98 for migration in migrations {
99 match applied_migrations.get(&migration.version) {
100 Some(applied_migration) => {
101 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
102 {
103 Err(anyhow!(
104 "checksum mismatch for applied migration {}",
105 migration.description
106 ))?;
107 }
108 }
109 None => {
110 let elapsed = connection.apply(&migration).await?;
111 new_migrations.push((migration, elapsed));
112 }
113 }
114 }
115
116 Ok(new_migrations)
117 }
118
119 async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
120 where
121 F: Send + Fn(TransactionHandle) -> Fut,
122 Fut: Send + Future<Output = Result<T>>,
123 {
124 let body = async {
125 let mut i = 0;
126 loop {
127 let (tx, result) = self.with_transaction(&f).await?;
128 match result {
129 Ok(result) => match tx.commit().await.map_err(Into::into) {
130 Ok(()) => return Ok(result),
131 Err(error) => {
132 if !self.retry_on_serialization_error(&error, i).await {
133 return Err(error);
134 }
135 }
136 },
137 Err(error) => {
138 tx.rollback().await?;
139 if !self.retry_on_serialization_error(&error, i).await {
140 return Err(error);
141 }
142 }
143 }
144 i += 1;
145 }
146 };
147
148 self.run(body).await
149 }
150
151 async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
152 where
153 F: Send + Fn(TransactionHandle) -> Fut,
154 Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
155 {
156 let body = async {
157 let mut i = 0;
158 loop {
159 let (tx, result) = self.with_transaction(&f).await?;
160 match result {
161 Ok(Some((room_id, data))) => {
162 let lock = self.rooms.entry(room_id).or_default().clone();
163 let _guard = lock.lock_owned().await;
164 match tx.commit().await.map_err(Into::into) {
165 Ok(()) => {
166 return Ok(Some(RoomGuard {
167 data,
168 _guard,
169 _not_send: PhantomData,
170 }));
171 }
172 Err(error) => {
173 if !self.retry_on_serialization_error(&error, i).await {
174 return Err(error);
175 }
176 }
177 }
178 }
179 Ok(None) => match tx.commit().await.map_err(Into::into) {
180 Ok(()) => return Ok(None),
181 Err(error) => {
182 if !self.retry_on_serialization_error(&error, i).await {
183 return Err(error);
184 }
185 }
186 },
187 Err(error) => {
188 tx.rollback().await?;
189 if !self.retry_on_serialization_error(&error, i).await {
190 return Err(error);
191 }
192 }
193 }
194 i += 1;
195 }
196 };
197
198 self.run(body).await
199 }
200
201 async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
202 where
203 F: Send + Fn(TransactionHandle) -> Fut,
204 Fut: Send + Future<Output = Result<T>>,
205 {
206 let body = async {
207 let mut i = 0;
208 loop {
209 let lock = self.rooms.entry(room_id).or_default().clone();
210 let _guard = lock.lock_owned().await;
211 let (tx, result) = self.with_transaction(&f).await?;
212 match result {
213 Ok(data) => match tx.commit().await.map_err(Into::into) {
214 Ok(()) => {
215 return Ok(RoomGuard {
216 data,
217 _guard,
218 _not_send: PhantomData,
219 });
220 }
221 Err(error) => {
222 if !self.retry_on_serialization_error(&error, i).await {
223 return Err(error);
224 }
225 }
226 },
227 Err(error) => {
228 tx.rollback().await?;
229 if !self.retry_on_serialization_error(&error, i).await {
230 return Err(error);
231 }
232 }
233 }
234 i += 1;
235 }
236 };
237
238 self.run(body).await
239 }
240
241 async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
242 where
243 F: Send + Fn(TransactionHandle) -> Fut,
244 Fut: Send + Future<Output = Result<T>>,
245 {
246 let tx = self
247 .pool
248 .begin_with_config(Some(IsolationLevel::Serializable), None)
249 .await?;
250
251 let mut tx = Arc::new(Some(tx));
252 let result = f(TransactionHandle(tx.clone())).await;
253 let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
254 return Err(anyhow!(
255 "couldn't complete transaction because it's still in use"
256 ))?;
257 };
258
259 Ok((tx, result))
260 }
261
262 async fn run<F, T>(&self, future: F) -> Result<T>
263 where
264 F: Future<Output = Result<T>>,
265 {
266 #[cfg(test)]
267 {
268 if let Executor::Deterministic(executor) = &self.executor {
269 executor.simulate_random_delay().await;
270 }
271
272 self.runtime.as_ref().unwrap().block_on(future)
273 }
274
275 #[cfg(not(test))]
276 {
277 future.await
278 }
279 }
280
281 async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
282 // If the error is due to a failure to serialize concurrent transactions, then retry
283 // this transaction after a delay. With each subsequent retry, double the delay duration.
284 // Also vary the delay randomly in order to ensure different database connections retry
285 // at different times.
286 if is_serialization_error(error) {
287 let base_delay = 4_u64 << prev_attempt_count.min(16);
288 let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
289 log::info!(
290 "retrying transaction after serialization error. delay: {} ms.",
291 randomized_delay
292 );
293 self.executor
294 .sleep(Duration::from_millis(randomized_delay as u64))
295 .await;
296 true
297 } else {
298 false
299 }
300 }
301}
302
303fn is_serialization_error(error: &Error) -> bool {
304 const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
305 match error {
306 Error::Database(
307 DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
308 | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
309 ) if error
310 .as_database_error()
311 .and_then(|error| error.code())
312 .as_deref()
313 == Some(SERIALIZATION_FAILURE_CODE) =>
314 {
315 true
316 }
317 _ => false,
318 }
319}
320
321struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
322
323impl Deref for TransactionHandle {
324 type Target = DatabaseTransaction;
325
326 fn deref(&self) -> &Self::Target {
327 self.0.as_ref().as_ref().unwrap()
328 }
329}
330
331pub struct RoomGuard<T> {
332 data: T,
333 _guard: OwnedMutexGuard<()>,
334 _not_send: PhantomData<Rc<()>>,
335}
336
337impl<T> Deref for RoomGuard<T> {
338 type Target = T;
339
340 fn deref(&self) -> &T {
341 &self.data
342 }
343}
344
345impl<T> DerefMut for RoomGuard<T> {
346 fn deref_mut(&mut self) -> &mut T {
347 &mut self.data
348 }
349}
350
351impl<T> RoomGuard<T> {
352 pub fn into_inner(self) -> T {
353 self.data
354 }
355}
356
/// A contact entry for `user_id`, in one of three states: accepted, an
/// outgoing request, or an incoming request.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    Accepted {
        user_id: UserId,
        // NOTE(review): presumably whether a notification is pending — confirm.
        should_notify: bool,
        // NOTE(review): presumably whether the contact is currently in a call — confirm.
        busy: bool,
    },
    Outgoing {
        user_id: UserId,
    },
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
372
373impl Contact {
374 pub fn user_id(&self) -> UserId {
375 match self {
376 Contact::Accepted { user_id, .. } => *user_id,
377 Contact::Outgoing { user_id } => *user_id,
378 Contact::Incoming { user_id, .. } => *user_id,
379 }
380 }
381}
382
/// An email address paired with its confirmation code. It can be loaded
/// directly from a query row and (de)serialized.
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
388
/// Data describing a new signup, deserialized from incoming input.
#[derive(Clone, Debug, Deserialize)]
pub struct NewSignup {
    pub email_address: String,
    /// Platforms selected by the signup.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
    // NOTE(review): when `None`, presumably the insert query supplies the
    // timestamp — confirm in the signup queries.
    pub created_at: Option<DateTime>,
}
401
/// Counters summarizing the waitlist, loaded from a single query row.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
pub struct WaitlistSummary {
    pub count: i64,
    pub linux_count: i64,
    pub mac_count: i64,
    pub windows_count: i64,
    // NOTE(review): presumably entries with no platform selected — confirm.
    pub unknown_count: i64,
}
410
/// Parameters for creating a user identified by a GitHub login and id.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    // NOTE(review): presumably the number of invites granted to the new user — confirm.
    pub invite_count: i32,
}
417
/// Outcome of creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// The inviting user, if any.
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
425
/// A channel's id and name, loadable directly from a query row. It is
/// hashable, so it can be stored in sets and maps.
#[derive(FromQueryResult, Debug, PartialEq, Eq, Hash)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
}
431
/// Channel information gathered for one user.
#[derive(Debug, PartialEq)]
pub struct ChannelsForUser {
    pub channels: ChannelGraph,
    /// User ids participating in each channel, keyed by channel.
    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
    /// Channels in which the user has admin privileges.
    pub channels_with_admin_privileges: HashSet<ChannelId>,
}
438
/// A rejoined channel buffer, paired with the connection id that was
/// previously associated with it.
#[derive(Debug)]
pub struct RejoinedChannelBuffer {
    pub buffer: proto::RejoinedChannelBuffer,
    pub old_connection_id: ConnectionId,
}
444
/// Result of joining a room: the room state and, when the room is tied to a
/// channel, that channel's id and members.
#[derive(Clone)]
pub struct JoinRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}
451
/// Result of rejoining a room: the room state plus the projects that were
/// rejoined or reshared.
pub struct RejoinedRoom {
    pub room: proto::Room,
    pub rejoined_projects: Vec<RejoinedProject>,
    pub reshared_projects: Vec<ResharedProject>,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}
459
/// A project reshared during a room rejoin, with the connection id previously
/// associated with it.
pub struct ResharedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<proto::WorktreeMetadata>,
}
466
/// A project rejoined during a room rejoin, with the connection id previously
/// associated with it and the worktree changes to send.
pub struct RejoinedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<RejoinedWorktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
474
/// Worktree state for a rejoined project, expressed as updated and removed
/// entries and repositories rather than a full snapshot.
#[derive(Debug)]
pub struct RejoinedWorktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub updated_entries: Vec<proto::Entry>,
    /// Ids of removed entries.
    pub removed_entries: Vec<u64>,
    pub updated_repositories: Vec<proto::RepositoryEntry>,
    pub removed_repositories: Vec<u64>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
490
/// Result of leaving a room.
pub struct LeftRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    /// Projects left as part of leaving the room, keyed by project id.
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
    // NOTE(review): presumably `true` when the room itself was deleted — confirm.
    pub deleted: bool,
}
499
/// Result of refreshing a room, including participants found to be stale.
pub struct RefreshedRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub stale_participant_user_ids: Vec<UserId>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
507
/// Result of refreshing a channel buffer: the connections involved and the
/// collaborator-removal messages produced.
pub struct RefreshedChannelBuffer {
    pub connection_ids: Vec<ConnectionId>,
    pub removed_collaborators: Vec<proto::RemoveChannelBufferCollaborator>,
}
512
/// A project's collaborators, worktrees (keyed and ordered by worktree id) and
/// language servers.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
518
/// A connection participating in a project.
pub struct ProjectCollaborator {
    pub connection_id: ConnectionId,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    /// Whether this collaborator is the project's host.
    pub is_host: bool,
}
525
526impl ProjectCollaborator {
527 pub fn to_proto(&self) -> proto::Collaborator {
528 proto::Collaborator {
529 peer_id: Some(self.connection_id.into()),
530 replica_id: self.replica_id.0 as u32,
531 user_id: self.user_id.to_proto(),
532 }
533 }
534}
535
/// A project that was left, with its host and the connections still
/// associated with it.
#[derive(Debug)]
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    pub connection_ids: Vec<ConnectionId>,
}
543
/// Full snapshot of a project worktree.
pub struct Worktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    // NOTE(review): key is presumably the repository's work-directory entry id — confirm.
    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
556
/// A settings file inside a worktree: its path and raw text content.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    pub path: String,
    pub content: String,
}