1#[cfg(test)]
2pub mod tests;
3
4#[cfg(test)]
5pub use tests::TestDb;
6
7mod ids;
8mod queries;
9mod tables;
10
11use crate::{executor::Executor, Error, Result};
12use anyhow::anyhow;
13use collections::{BTreeMap, HashMap, HashSet};
14use dashmap::DashMap;
15use futures::StreamExt;
16use rand::{prelude::StdRng, Rng, SeedableRng};
17use rpc::{
18 proto::{self},
19 ConnectionId,
20};
21use sea_orm::{
22 entity::prelude::*,
23 sea_query::{Alias, Expr, OnConflict},
24 ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
25 FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
26 TransactionTrait,
27};
28use serde::{Deserialize, Serialize};
29use sqlx::{
30 migrate::{Migrate, Migration, MigrationSource},
31 Connection,
32};
33use std::{
34 fmt::Write as _,
35 future::Future,
36 marker::PhantomData,
37 ops::{Deref, DerefMut},
38 path::Path,
39 rc::Rc,
40 sync::Arc,
41 time::Duration,
42};
43use tables::*;
44use tokio::sync::{Mutex, OwnedMutexGuard};
45
46pub use ids::*;
47pub use sea_orm::ConnectOptions;
48pub use tables::user::Model as User;
49
50use self::queries::channels::ChannelGraph;
51
/// Entry point for database access.
///
/// Holds the sea-orm connection pool together with in-process state used by the
/// transaction helpers: per-room locks and the RNG that jitters retry delays.
pub struct Database {
    /// Options used to open `pool`; `migrate` also reuses their URL to open a
    /// dedicated sqlx connection.
    options: ConnectOptions,
    /// Pool from which every transaction is started.
    pool: DatabaseConnection,
    /// One async mutex per room. Room-scoped transactions hold the room's lock while
    /// committing and for as long as the returned `RoomGuard` lives. This map is
    /// in-memory, so it only serializes work within this process.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// RNG used to randomize serialization-retry delays. `new` seeds it with a fixed value (0).
    rng: Mutex<StdRng>,
    /// Executor used to sleep between retries (and to inject simulated delays in tests).
    executor: Executor,
    /// Notification kind names keyed by id. Starts empty in `new`; presumably filled by
    /// `initialize_static_data` — confirm in the notification queries.
    notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
    /// Reverse of `notification_kinds_by_id` (name -> id). Also starts empty.
    notification_kinds_by_name: HashMap<String, NotificationKindId>,
    /// Test-only runtime on which `run` blocks. `new` leaves it `None`, and `run` unwraps
    /// it, so test setup must populate it.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
63
64// The `Database` type has so many methods that its impl blocks are split into
65// separate files in the `queries` folder.
66impl Database {
67 pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
68 sqlx::any::install_default_drivers();
69 Ok(Self {
70 options: options.clone(),
71 pool: sea_orm::Database::connect(options).await?,
72 rooms: DashMap::with_capacity(16384),
73 rng: Mutex::new(StdRng::seed_from_u64(0)),
74 notification_kinds_by_id: HashMap::default(),
75 notification_kinds_by_name: HashMap::default(),
76 executor,
77 #[cfg(test)]
78 runtime: None,
79 })
80 }
81
82 #[cfg(test)]
83 pub fn reset(&self) {
84 self.rooms.clear();
85 }
86
87 pub async fn migrate(
88 &self,
89 migrations_path: &Path,
90 ignore_checksum_mismatch: bool,
91 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
92 let migrations = MigrationSource::resolve(migrations_path)
93 .await
94 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
95
96 let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
97
98 connection.ensure_migrations_table().await?;
99 let applied_migrations: HashMap<_, _> = connection
100 .list_applied_migrations()
101 .await?
102 .into_iter()
103 .map(|m| (m.version, m))
104 .collect();
105
106 let mut new_migrations = Vec::new();
107 for migration in migrations {
108 match applied_migrations.get(&migration.version) {
109 Some(applied_migration) => {
110 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
111 {
112 Err(anyhow!(
113 "checksum mismatch for applied migration {}",
114 migration.description
115 ))?;
116 }
117 }
118 None => {
119 let elapsed = connection.apply(&migration).await?;
120 new_migrations.push((migration, elapsed));
121 }
122 }
123 }
124
125 Ok(new_migrations)
126 }
127
128 pub async fn initialize_static_data(&mut self) -> Result<()> {
129 self.initialize_notification_enum().await?;
130 Ok(())
131 }
132
133 pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
134 where
135 F: Send + Fn(TransactionHandle) -> Fut,
136 Fut: Send + Future<Output = Result<T>>,
137 {
138 let body = async {
139 let mut i = 0;
140 loop {
141 let (tx, result) = self.with_transaction(&f).await?;
142 match result {
143 Ok(result) => match tx.commit().await.map_err(Into::into) {
144 Ok(()) => return Ok(result),
145 Err(error) => {
146 if !self.retry_on_serialization_error(&error, i).await {
147 return Err(error);
148 }
149 }
150 },
151 Err(error) => {
152 tx.rollback().await?;
153 if !self.retry_on_serialization_error(&error, i).await {
154 return Err(error);
155 }
156 }
157 }
158 i += 1;
159 }
160 };
161
162 self.run(body).await
163 }
164
165 async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
166 where
167 F: Send + Fn(TransactionHandle) -> Fut,
168 Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
169 {
170 let body = async {
171 let mut i = 0;
172 loop {
173 let (tx, result) = self.with_transaction(&f).await?;
174 match result {
175 Ok(Some((room_id, data))) => {
176 let lock = self.rooms.entry(room_id).or_default().clone();
177 let _guard = lock.lock_owned().await;
178 match tx.commit().await.map_err(Into::into) {
179 Ok(()) => {
180 return Ok(Some(RoomGuard {
181 data,
182 _guard,
183 _not_send: PhantomData,
184 }));
185 }
186 Err(error) => {
187 if !self.retry_on_serialization_error(&error, i).await {
188 return Err(error);
189 }
190 }
191 }
192 }
193 Ok(None) => match tx.commit().await.map_err(Into::into) {
194 Ok(()) => return Ok(None),
195 Err(error) => {
196 if !self.retry_on_serialization_error(&error, i).await {
197 return Err(error);
198 }
199 }
200 },
201 Err(error) => {
202 tx.rollback().await?;
203 if !self.retry_on_serialization_error(&error, i).await {
204 return Err(error);
205 }
206 }
207 }
208 i += 1;
209 }
210 };
211
212 self.run(body).await
213 }
214
215 async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
216 where
217 F: Send + Fn(TransactionHandle) -> Fut,
218 Fut: Send + Future<Output = Result<T>>,
219 {
220 let body = async {
221 let mut i = 0;
222 loop {
223 let lock = self.rooms.entry(room_id).or_default().clone();
224 let _guard = lock.lock_owned().await;
225 let (tx, result) = self.with_transaction(&f).await?;
226 match result {
227 Ok(data) => match tx.commit().await.map_err(Into::into) {
228 Ok(()) => {
229 return Ok(RoomGuard {
230 data,
231 _guard,
232 _not_send: PhantomData,
233 });
234 }
235 Err(error) => {
236 if !self.retry_on_serialization_error(&error, i).await {
237 return Err(error);
238 }
239 }
240 },
241 Err(error) => {
242 tx.rollback().await?;
243 if !self.retry_on_serialization_error(&error, i).await {
244 return Err(error);
245 }
246 }
247 }
248 i += 1;
249 }
250 };
251
252 self.run(body).await
253 }
254
255 async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
256 where
257 F: Send + Fn(TransactionHandle) -> Fut,
258 Fut: Send + Future<Output = Result<T>>,
259 {
260 let tx = self
261 .pool
262 .begin_with_config(Some(IsolationLevel::Serializable), None)
263 .await?;
264
265 let mut tx = Arc::new(Some(tx));
266 let result = f(TransactionHandle(tx.clone())).await;
267 let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
268 return Err(anyhow!(
269 "couldn't complete transaction because it's still in use"
270 ))?;
271 };
272
273 Ok((tx, result))
274 }
275
276 async fn run<F, T>(&self, future: F) -> Result<T>
277 where
278 F: Future<Output = Result<T>>,
279 {
280 #[cfg(test)]
281 {
282 if let Executor::Deterministic(executor) = &self.executor {
283 executor.simulate_random_delay().await;
284 }
285
286 self.runtime.as_ref().unwrap().block_on(future)
287 }
288
289 #[cfg(not(test))]
290 {
291 future.await
292 }
293 }
294
295 async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
296 // If the error is due to a failure to serialize concurrent transactions, then retry
297 // this transaction after a delay. With each subsequent retry, double the delay duration.
298 // Also vary the delay randomly in order to ensure different database connections retry
299 // at different times.
300 if is_serialization_error(error) {
301 let base_delay = 4_u64 << prev_attempt_count.min(16);
302 let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
303 log::info!(
304 "retrying transaction after serialization error. delay: {} ms.",
305 randomized_delay
306 );
307 self.executor
308 .sleep(Duration::from_millis(randomized_delay as u64))
309 .await;
310 true
311 } else {
312 false
313 }
314 }
315}
316
317fn is_serialization_error(error: &Error) -> bool {
318 const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
319 match error {
320 Error::Database(
321 DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
322 | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
323 ) if error
324 .as_database_error()
325 .and_then(|error| error.code())
326 .as_deref()
327 == Some(SERIALIZATION_FAILURE_CODE) =>
328 {
329 true
330 }
331 _ => false,
332 }
333}
334
/// Shared handle to an in-flight transaction, passed to transaction closures.
///
/// Dereferences to the underlying `DatabaseTransaction`. `Database::with_transaction`
/// reclaims the transaction after the closure's future completes, and reports an error if
/// any reference to this handle is still alive at that point.
pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
336
337impl Deref for TransactionHandle {
338 type Target = DatabaseTransaction;
339
340 fn deref(&self) -> &Self::Target {
341 self.0.as_ref().as_ref().unwrap()
342 }
343}
344
/// The result of a room-scoped transaction, returned while still holding that room's lock.
///
/// The room's mutex guard is released when this value is dropped (or consumed via
/// `into_inner`). The `PhantomData<Rc<()>>` marker makes the type `!Send`, so the guard
/// cannot be sent to another thread.
pub struct RoomGuard<T> {
    data: T,
    _guard: OwnedMutexGuard<()>,
    _not_send: PhantomData<Rc<()>>,
}
350
351impl<T> Deref for RoomGuard<T> {
352 type Target = T;
353
354 fn deref(&self) -> &T {
355 &self.data
356 }
357}
358
359impl<T> DerefMut for RoomGuard<T> {
360 fn deref_mut(&mut self) -> &mut T {
361 &mut self.data
362 }
363}
364
365impl<T> RoomGuard<T> {
366 pub fn into_inner(self) -> T {
367 self.data
368 }
369}
370
/// A contact entry, tagged by its state: `Accepted`, `Outgoing`, or `Incoming`.
///
/// Every variant carries the contact's `user_id`. Accepted contacts also carry a `busy`
/// flag, whose meaning is defined by the queries that build this value.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    Accepted { user_id: UserId, busy: bool },
    Outgoing { user_id: UserId },
    Incoming { user_id: UserId },
}
377
378impl Contact {
379 pub fn user_id(&self) -> UserId {
380 match self {
381 Contact::Accepted { user_id, .. } => *user_id,
382 Contact::Outgoing { user_id } => *user_id,
383 Contact::Incoming { user_id, .. } => *user_id,
384 }
385 }
386}
387
/// An email address together with its `email_confirmation_code`.
///
/// Derives `FromQueryResult`, so it can be built directly from a query row, and is
/// serde-(de)serializable.
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
393
/// Details of a new signup.
///
/// Derives `Deserialize` but not `Serialize`. `device_id` and `created_at` are optional.
/// How a missing `created_at` is handled is decided by the consuming query (TODO confirm).
#[derive(Clone, Debug, Deserialize)]
pub struct NewSignup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
    pub created_at: Option<DateTime>,
}
406
/// Aggregate waitlist counts, loadable directly from a query row (`FromQueryResult`).
///
/// Presumably `count` is the total and the other fields break it down by platform —
/// confirm against the query that produces it.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
pub struct WaitlistSummary {
    pub count: i64,
    pub linux_count: i64,
    pub mac_count: i64,
    pub windows_count: i64,
    pub unknown_count: i64,
}
415
/// Parameters for creating a user from a GitHub account.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
422
/// Outcome of creating a user.
///
/// `inviting_user_id` and `signup_device_id` are optional; they are presumably set only
/// when the user came from an invite or a signup, respectively — TODO confirm.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
430
/// A channel's id and name, loadable directly from a query row (`FromQueryResult`) and
/// hashable so it can be stored in sets.
#[derive(FromQueryResult, Debug, PartialEq, Eq, Hash)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
}
436
/// Everything about channels that is loaded for a single user.
///
/// Contents:
/// - the channel graph;
/// - participants per channel;
/// - which channels the user has admin privileges in;
/// - unseen buffer-change and message notices, in their protobuf forms.
#[derive(Debug, PartialEq)]
pub struct ChannelsForUser {
    pub channels: ChannelGraph,
    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
    pub channels_with_admin_privileges: HashSet<ChannelId>,
    pub unseen_buffer_changes: Vec<proto::UnseenChannelBufferChange>,
    pub channel_messages: Vec<proto::UnseenChannelMessage>,
}
445
/// A channel buffer rejoined on reconnect, paired with the connection id the client had
/// before reconnecting.
#[derive(Debug)]
pub struct RejoinedChannelBuffer {
    pub buffer: proto::RejoinedChannelBuffer,
    pub old_connection_id: ConnectionId,
}
451
/// Result of joining a room: the room's protobuf state and, when the room belongs to a
/// channel, that channel's id and members.
#[derive(Clone)]
pub struct JoinRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}
458
/// Result of rejoining a room after a reconnect.
///
/// Contains the room state, the projects the client rejoined as a guest and the projects
/// it reshared as host, plus the room's channel id and members when it belongs to a channel.
pub struct RejoinedRoom {
    pub room: proto::Room,
    pub rejoined_projects: Vec<RejoinedProject>,
    pub reshared_projects: Vec<ResharedProject>,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}
466
/// A project reshared after its host reconnected. It carries the connection id the host
/// used before reconnecting, plus the project's collaborators and worktree metadata.
pub struct ResharedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<proto::WorktreeMetadata>,
}
473
/// A project rejoined after a reconnect.
///
/// Carries the previous connection id, the current collaborators, the per-worktree updates
/// the client needs, and the project's language servers.
pub struct RejoinedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<RejoinedWorktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
481
/// Worktree state sent to a client rejoining a project.
///
/// Changes are expressed as updated and removed entries and repositories (removals are
/// given by id). `scan_id` / `completed_scan_id` presumably let the client tell how
/// up-to-date the worktree is — TODO confirm.
#[derive(Debug)]
pub struct RejoinedWorktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub updated_entries: Vec<proto::Entry>,
    pub removed_entries: Vec<u64>,
    pub updated_repositories: Vec<proto::RepositoryEntry>,
    pub removed_repositories: Vec<u64>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
497
/// Result of a participant leaving a room.
///
/// Contents:
/// - the updated room state and its channel info;
/// - the projects affected (keyed by project id);
/// - the users whose pending calls were canceled.
///
/// `deleted` presumably reports whether the room itself was removed — confirm in the
/// room queries.
pub struct LeftRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
    pub deleted: bool,
}
506
/// Result of refreshing a room.
///
/// Holds the room state, its channel info, the ids of participants considered stale, and
/// the users whose calls were canceled as a result.
pub struct RefreshedRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub stale_participant_user_ids: Vec<UserId>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
514
/// A channel buffer's connections and collaborators after a refresh.
pub struct RefreshedChannelBuffer {
    pub connection_ids: Vec<ConnectionId>,
    pub collaborators: Vec<proto::Collaborator>,
}
519
/// A shared project's collaborators, worktrees (ordered by worktree id), and language
/// servers.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
525
/// A participant in a project: their connection, user, replica id, and whether they host
/// the project. Converted to `proto::Collaborator` by `to_proto` (`is_host` is not
/// included in that conversion).
pub struct ProjectCollaborator {
    pub connection_id: ConnectionId,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    pub is_host: bool,
}
532
533impl ProjectCollaborator {
534 pub fn to_proto(&self) -> proto::Collaborator {
535 proto::Collaborator {
536 peer_id: Some(self.connection_id.into()),
537 replica_id: self.replica_id.0 as u32,
538 user_id: self.user_id.to_proto(),
539 }
540 }
541}
542
/// A project affected by a participant leaving.
///
/// Records the project's host (user and connection) and a list of connection ids that
/// presumably still need to be notified — confirm in the leave-room query.
#[derive(Debug)]
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    pub connection_ids: Vec<ConnectionId>,
}
550
/// Full worktree state stored for a shared project.
///
/// Includes entries, repository entries (ordered by `u64` key), diagnostic summaries,
/// settings files, and scan ids.
pub struct Worktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}
563
/// A settings file within a worktree: its path and raw text content.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    pub path: String,
    pub content: String,
}