1#[cfg(test)]
2pub mod tests;
3
4#[cfg(test)]
5pub use tests::TestDb;
6
7mod ids;
8mod queries;
9mod tables;
10
11use crate::{executor::Executor, Error, Result};
12use anyhow::anyhow;
13use collections::{BTreeMap, HashMap, HashSet};
14use dashmap::DashMap;
15use futures::StreamExt;
16use rand::{prelude::StdRng, Rng, SeedableRng};
17use rpc::{
18 proto::{self},
19 ConnectionId,
20};
21use sea_orm::{
22 entity::prelude::*,
23 sea_query::{Alias, Expr, OnConflict, Query},
24 ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
25 FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
26 TransactionTrait,
27};
28use serde::{Deserialize, Serialize};
29use sqlx::{
30 migrate::{Migrate, Migration, MigrationSource},
31 Connection,
32};
33use std::{
34 fmt::Write as _,
35 future::Future,
36 marker::PhantomData,
37 ops::{Deref, DerefMut},
38 path::Path,
39 rc::Rc,
40 sync::Arc,
41 time::Duration,
42};
43use tables::*;
44use tokio::sync::{Mutex, OwnedMutexGuard};
45
46pub use ids::*;
47pub use sea_orm::ConnectOptions;
48pub use tables::user::Model as User;
49
50use self::queries::channels::ChannelGraph;
51
/// Handle to the collaboration database.
///
/// Owns the sea-orm connection pool together with the state used by the
/// transaction helpers: lazily created per-room locks, the RNG that jitters
/// serialization-retry delays, and the executor used to sleep between retries.
pub struct Database {
    /// Connection options; kept so `migrate` can open its own sqlx connection to the same URL.
    options: ConnectOptions,
    /// sea-orm connection pool from which every transaction is started.
    pool: DatabaseConnection,
    /// One mutex per room, created on first use (`entry(..).or_default()`). Room-scoped
    /// transactions acquire it, and the returned `RoomGuard` keeps it locked.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// Randomness used to jitter the delay before retrying a serialization failure.
    rng: Mutex<StdRng>,
    /// Executor used to sleep between retries (and, in tests, to inject random delays).
    executor: Executor,
    /// Test-only runtime on which `run` blocks to drive transaction futures.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
61
62// The `Database` type has so many methods that its impl blocks are split into
63// separate files in the `queries` folder.
64impl Database {
65 pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
66 sqlx::any::install_default_drivers();
67 Ok(Self {
68 options: options.clone(),
69 pool: sea_orm::Database::connect(options).await?,
70 rooms: DashMap::with_capacity(16384),
71 rng: Mutex::new(StdRng::seed_from_u64(0)),
72 executor,
73 #[cfg(test)]
74 runtime: None,
75 })
76 }
77
78 #[cfg(test)]
79 pub fn reset(&self) {
80 self.rooms.clear();
81 }
82
83 pub async fn migrate(
84 &self,
85 migrations_path: &Path,
86 ignore_checksum_mismatch: bool,
87 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
88 let migrations = MigrationSource::resolve(migrations_path)
89 .await
90 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
91
92 let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
93
94 connection.ensure_migrations_table().await?;
95 let applied_migrations: HashMap<_, _> = connection
96 .list_applied_migrations()
97 .await?
98 .into_iter()
99 .map(|m| (m.version, m))
100 .collect();
101
102 let mut new_migrations = Vec::new();
103 for migration in migrations {
104 match applied_migrations.get(&migration.version) {
105 Some(applied_migration) => {
106 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
107 {
108 Err(anyhow!(
109 "checksum mismatch for applied migration {}",
110 migration.description
111 ))?;
112 }
113 }
114 None => {
115 let elapsed = connection.apply(&migration).await?;
116 new_migrations.push((migration, elapsed));
117 }
118 }
119 }
120
121 Ok(new_migrations)
122 }
123
124 pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
125 where
126 F: Send + Fn(TransactionHandle) -> Fut,
127 Fut: Send + Future<Output = Result<T>>,
128 {
129 let body = async {
130 let mut i = 0;
131 loop {
132 let (tx, result) = self.with_transaction(&f).await?;
133 match result {
134 Ok(result) => match tx.commit().await.map_err(Into::into) {
135 Ok(()) => return Ok(result),
136 Err(error) => {
137 if !self.retry_on_serialization_error(&error, i).await {
138 return Err(error);
139 }
140 }
141 },
142 Err(error) => {
143 tx.rollback().await?;
144 if !self.retry_on_serialization_error(&error, i).await {
145 return Err(error);
146 }
147 }
148 }
149 i += 1;
150 }
151 };
152
153 self.run(body).await
154 }
155
156 async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
157 where
158 F: Send + Fn(TransactionHandle) -> Fut,
159 Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
160 {
161 let body = async {
162 let mut i = 0;
163 loop {
164 let (tx, result) = self.with_transaction(&f).await?;
165 match result {
166 Ok(Some((room_id, data))) => {
167 let lock = self.rooms.entry(room_id).or_default().clone();
168 let _guard = lock.lock_owned().await;
169 match tx.commit().await.map_err(Into::into) {
170 Ok(()) => {
171 return Ok(Some(RoomGuard {
172 data,
173 _guard,
174 _not_send: PhantomData,
175 }));
176 }
177 Err(error) => {
178 if !self.retry_on_serialization_error(&error, i).await {
179 return Err(error);
180 }
181 }
182 }
183 }
184 Ok(None) => match tx.commit().await.map_err(Into::into) {
185 Ok(()) => return Ok(None),
186 Err(error) => {
187 if !self.retry_on_serialization_error(&error, i).await {
188 return Err(error);
189 }
190 }
191 },
192 Err(error) => {
193 tx.rollback().await?;
194 if !self.retry_on_serialization_error(&error, i).await {
195 return Err(error);
196 }
197 }
198 }
199 i += 1;
200 }
201 };
202
203 self.run(body).await
204 }
205
206 async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
207 where
208 F: Send + Fn(TransactionHandle) -> Fut,
209 Fut: Send + Future<Output = Result<T>>,
210 {
211 let body = async {
212 let mut i = 0;
213 loop {
214 let lock = self.rooms.entry(room_id).or_default().clone();
215 let _guard = lock.lock_owned().await;
216 let (tx, result) = self.with_transaction(&f).await?;
217 match result {
218 Ok(data) => match tx.commit().await.map_err(Into::into) {
219 Ok(()) => {
220 return Ok(RoomGuard {
221 data,
222 _guard,
223 _not_send: PhantomData,
224 });
225 }
226 Err(error) => {
227 if !self.retry_on_serialization_error(&error, i).await {
228 return Err(error);
229 }
230 }
231 },
232 Err(error) => {
233 tx.rollback().await?;
234 if !self.retry_on_serialization_error(&error, i).await {
235 return Err(error);
236 }
237 }
238 }
239 i += 1;
240 }
241 };
242
243 self.run(body).await
244 }
245
246 async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
247 where
248 F: Send + Fn(TransactionHandle) -> Fut,
249 Fut: Send + Future<Output = Result<T>>,
250 {
251 let tx = self
252 .pool
253 .begin_with_config(Some(IsolationLevel::Serializable), None)
254 .await?;
255
256 let mut tx = Arc::new(Some(tx));
257 let result = f(TransactionHandle(tx.clone())).await;
258 let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
259 return Err(anyhow!(
260 "couldn't complete transaction because it's still in use"
261 ))?;
262 };
263
264 Ok((tx, result))
265 }
266
267 async fn run<F, T>(&self, future: F) -> Result<T>
268 where
269 F: Future<Output = Result<T>>,
270 {
271 #[cfg(test)]
272 {
273 if let Executor::Deterministic(executor) = &self.executor {
274 executor.simulate_random_delay().await;
275 }
276
277 self.runtime.as_ref().unwrap().block_on(future)
278 }
279
280 #[cfg(not(test))]
281 {
282 future.await
283 }
284 }
285
286 async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
287 // If the error is due to a failure to serialize concurrent transactions, then retry
288 // this transaction after a delay. With each subsequent retry, double the delay duration.
289 // Also vary the delay randomly in order to ensure different database connections retry
290 // at different times.
291 if is_serialization_error(error) {
292 let base_delay = 4_u64 << prev_attempt_count.min(16);
293 let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
294 log::info!(
295 "retrying transaction after serialization error. delay: {} ms.",
296 randomized_delay
297 );
298 self.executor
299 .sleep(Duration::from_millis(randomized_delay as u64))
300 .await;
301 true
302 } else {
303 false
304 }
305 }
306}
307
308fn is_serialization_error(error: &Error) -> bool {
309 const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
310 match error {
311 Error::Database(
312 DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
313 | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
314 ) if error
315 .as_database_error()
316 .and_then(|error| error.code())
317 .as_deref()
318 == Some(SERIALIZATION_FAILURE_CODE) =>
319 {
320 true
321 }
322 _ => false,
323 }
324}
325
326pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
327
328impl Deref for TransactionHandle {
329 type Target = DatabaseTransaction;
330
331 fn deref(&self) -> &Self::Target {
332 self.0.as_ref().as_ref().unwrap()
333 }
334}
335
336pub struct RoomGuard<T> {
337 data: T,
338 _guard: OwnedMutexGuard<()>,
339 _not_send: PhantomData<Rc<()>>,
340}
341
342impl<T> Deref for RoomGuard<T> {
343 type Target = T;
344
345 fn deref(&self) -> &T {
346 &self.data
347 }
348}
349
350impl<T> DerefMut for RoomGuard<T> {
351 fn deref_mut(&mut self) -> &mut T {
352 &mut self.data
353 }
354}
355
356impl<T> RoomGuard<T> {
357 pub fn into_inner(self) -> T {
358 self.data
359 }
360}
361
362#[derive(Clone, Debug, PartialEq, Eq)]
363pub enum Contact {
364 Accepted {
365 user_id: UserId,
366 should_notify: bool,
367 busy: bool,
368 },
369 Outgoing {
370 user_id: UserId,
371 },
372 Incoming {
373 user_id: UserId,
374 should_notify: bool,
375 },
376}
377
378impl Contact {
379 pub fn user_id(&self) -> UserId {
380 match self {
381 Contact::Accepted { user_id, .. } => *user_id,
382 Contact::Outgoing { user_id } => *user_id,
383 Contact::Incoming { user_id, .. } => *user_id,
384 }
385 }
386}
387
/// An email invite: the address together with its confirmation code.
///
/// Derives `FromQueryResult` so it can be read directly from a sea-orm query,
/// and serde's traits so it can be serialized and deserialized.
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
393
/// Fields describing a new signup.
///
/// Only `Deserialize` is derived: this type is parsed from serialized input and
/// is not serialized back out.
#[derive(Clone, Debug, Deserialize)]
pub struct NewSignup {
    pub email_address: String,
    // Platform checkboxes; more than one may be set.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
    /// Optional creation time; NOTE(review): presumably filled in elsewhere when `None` — confirm.
    pub created_at: Option<DateTime>,
}
406
/// Aggregate waitlist counts, loaded directly from a query via `FromQueryResult`.
///
/// NOTE(review): presumably `count` is the total and the other fields are
/// per-platform breakdowns — confirm against the producing query.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
pub struct WaitlistSummary {
    pub count: i64,
    pub linux_count: i64,
    pub mac_count: i64,
    pub windows_count: i64,
    pub unknown_count: i64,
}
415
/// Parameters for creating a user from a GitHub account.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    /// Number of invites to grant the new user.
    pub invite_count: i32,
}
422
/// Outcome of creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// The user whose invite led to this signup, if any.
    pub inviting_user_id: Option<UserId>,
    /// Device id recorded at signup, if one was provided.
    pub signup_device_id: Option<String>,
}
430
/// A channel's id and name.
///
/// Loaded via `FromQueryResult`; `Eq`/`Hash` allow use in hash-based collections.
#[derive(FromQueryResult, Debug, PartialEq, Eq, Hash)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
}
436
/// Channel state loaded for a single user.
#[derive(Debug, PartialEq)]
pub struct ChannelsForUser {
    /// The user's channels as a `ChannelGraph` (defined in `queries::channels`).
    pub channels: ChannelGraph,
    /// Participant user ids, keyed by channel.
    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
    /// Channels in which the user has admin privileges.
    pub channels_with_admin_privileges: HashSet<ChannelId>,
    /// Channel buffer changes the user has not seen.
    pub unseen_buffer_changes: Vec<proto::UnseenChannelBufferChange>,
    /// Channel messages the user has not seen.
    pub channel_messages: Vec<proto::UnseenChannelMessage>,
}
445
/// A channel buffer rejoined after a reconnect.
#[derive(Debug)]
pub struct RejoinedChannelBuffer {
    pub buffer: proto::RejoinedChannelBuffer,
    /// The connection that previously held this buffer.
    pub old_connection_id: ConnectionId,
}
451
/// Result of joining a room.
#[derive(Clone)]
pub struct JoinRoom {
    pub room: proto::Room,
    /// The channel the room belongs to, if it is a channel room.
    pub channel_id: Option<ChannelId>,
    /// NOTE(review): presumably empty when `channel_id` is `None` — confirm.
    pub channel_members: Vec<UserId>,
}
458
/// Result of rejoining a room after a reconnect.
pub struct RejoinedRoom {
    pub room: proto::Room,
    /// Projects this user rejoined as a guest.
    pub rejoined_projects: Vec<RejoinedProject>,
    /// Projects this user re-shared as host.
    pub reshared_projects: Vec<ResharedProject>,
    /// The channel the room belongs to, if any.
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}
466
/// A project re-shared by its host after a reconnect.
pub struct ResharedProject {
    pub id: ProjectId,
    /// The host connection that previously shared the project.
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<proto::WorktreeMetadata>,
}
473
/// A project rejoined by a collaborator after a reconnect.
pub struct RejoinedProject {
    pub id: ProjectId,
    /// The connection the collaborator previously used for this project.
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    /// Per-worktree state the rejoining collaborator needs to catch up on.
    pub worktrees: Vec<RejoinedWorktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
481
/// Worktree state sent to a collaborator rejoining a project.
///
/// NOTE(review): the `updated_*`/`removed_*` fields look like a delta relative to
/// what the collaborator already had — confirm against the query that builds this.
#[derive(Debug)]
pub struct RejoinedWorktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub updated_entries: Vec<proto::Entry>,
    /// Ids of entries that were removed.
    pub removed_entries: Vec<u64>,
    pub updated_repositories: Vec<proto::RepositoryEntry>,
    /// Ids of repositories that were removed.
    pub removed_repositories: Vec<u64>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
497
/// Result of a user leaving a room.
pub struct LeftRoom {
    pub room: proto::Room,
    /// The channel the room belongs to, if any.
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    /// Projects affected by the departure, keyed by project id.
    pub left_projects: HashMap<ProjectId, LeftProject>,
    /// Users whose pending calls were canceled as a result.
    pub canceled_calls_to_user_ids: Vec<UserId>,
    /// Whether the room was deleted. NOTE(review): presumably because it became empty — confirm.
    pub deleted: bool,
}
506
/// Result of refreshing a room's participants.
pub struct RefreshedRoom {
    pub room: proto::Room,
    /// The channel the room belongs to, if any.
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    /// Participants considered stale by the refresh.
    pub stale_participant_user_ids: Vec<UserId>,
    /// Users whose pending calls were canceled as a result.
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
514
/// Result of refreshing a channel buffer's collaborators.
pub struct RefreshedChannelBuffer {
    /// NOTE(review): presumably the connections to notify about the change — confirm.
    pub connection_ids: Vec<ConnectionId>,
    pub collaborators: Vec<proto::Collaborator>,
}
519
/// A shared project: its collaborators, worktrees, and language servers.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    /// Worktrees keyed by worktree id.
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
525
/// One participant in a shared project.
pub struct ProjectCollaborator {
    pub connection_id: ConnectionId,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    /// Whether this collaborator is the project's host.
    pub is_host: bool,
}
532
533impl ProjectCollaborator {
534 pub fn to_proto(&self) -> proto::Collaborator {
535 proto::Collaborator {
536 peer_id: Some(self.connection_id.into()),
537 replica_id: self.replica_id.0 as u32,
538 user_id: self.user_id.to_proto(),
539 }
540 }
541}
542
/// A project affected by a collaborator leaving.
#[derive(Debug)]
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    /// NOTE(review): presumably the remaining connections to notify — confirm at the query site.
    pub connection_ids: Vec<ConnectionId>,
}
550
/// Full state of one worktree in a shared project.
pub struct Worktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    /// Repository entries keyed by a `u64` id. NOTE(review): presumably the work-directory entry id — confirm.
    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
563
/// A settings file inside a worktree: its path and raw contents.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    pub path: String,
    pub content: String,
}