1#[cfg(test)]
2pub mod tests;
3
4#[cfg(test)]
5pub use tests::TestDb;
6
7mod ids;
8mod queries;
9mod tables;
10
11use crate::{executor::Executor, Error, Result};
12use anyhow::anyhow;
13use collections::{BTreeMap, HashMap, HashSet};
14use dashmap::DashMap;
15use futures::StreamExt;
16use rand::{prelude::StdRng, Rng, SeedableRng};
17use rpc::{
18 proto::{self},
19 ConnectionId,
20};
21use sea_orm::{
22 entity::prelude::*,
23 sea_query::{Alias, Expr, OnConflict},
24 ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
25 FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
26 TransactionTrait,
27};
28use serde::{Deserialize, Serialize};
29use sqlx::{
30 migrate::{Migrate, Migration, MigrationSource},
31 Connection,
32};
33use std::{
34 fmt::Write as _,
35 future::Future,
36 marker::PhantomData,
37 ops::{Deref, DerefMut},
38 path::Path,
39 rc::Rc,
40 sync::Arc,
41 time::Duration,
42};
43use tables::*;
44use tokio::sync::{Mutex, OwnedMutexGuard};
45
46pub use ids::*;
47pub use sea_orm::ConnectOptions;
48pub use tables::user::Model as User;
49
/// Database gives you a handle that lets you access the database.
/// It handles pooling internally.
pub struct Database {
    /// Connection options; retained so `migrate` can open its own raw `sqlx`
    /// connection to the same URL.
    options: ConnectOptions,
    /// The sea-orm connection that transactions are begun on.
    pool: DatabaseConnection,
    /// One async mutex per room. Room transactions lock the room's mutex so that
    /// commits (and the client updates sent while the returned guard is alive)
    /// are serialized per room. Entries are created on demand and only cleared
    /// by the test-only `reset`.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// Source of jitter for serialization-failure retry delays; seeded with a
    /// fixed value in `new`.
    rng: Mutex<StdRng>,
    /// Used to sleep between retries and, in tests, to inject simulated delays.
    executor: Executor,
    /// Lookup tables for notification kinds; empty after `new`, presumably filled
    /// by `initialize_static_data` — confirm in the notification queries.
    notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
    notification_kinds_by_name: HashMap<String, NotificationKindId>,
    /// Test-only runtime that `run` blocks on. Declared last, so it is dropped
    /// after `pool`.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
63
64// The `Database` type has so many methods that its impl blocks are split into
65// separate files in the `queries` folder.
66impl Database {
67 /// Connects to the database with the given options
68 pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
69 sqlx::any::install_default_drivers();
70 Ok(Self {
71 options: options.clone(),
72 pool: sea_orm::Database::connect(options).await?,
73 rooms: DashMap::with_capacity(16384),
74 rng: Mutex::new(StdRng::seed_from_u64(0)),
75 notification_kinds_by_id: HashMap::default(),
76 notification_kinds_by_name: HashMap::default(),
77 executor,
78 #[cfg(test)]
79 runtime: None,
80 })
81 }
82
83 #[cfg(test)]
84 pub fn reset(&self) {
85 self.rooms.clear();
86 }
87
88 /// Runs the database migrations.
89 pub async fn migrate(
90 &self,
91 migrations_path: &Path,
92 ignore_checksum_mismatch: bool,
93 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
94 let migrations = MigrationSource::resolve(migrations_path)
95 .await
96 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
97
98 let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
99
100 connection.ensure_migrations_table().await?;
101 let applied_migrations: HashMap<_, _> = connection
102 .list_applied_migrations()
103 .await?
104 .into_iter()
105 .map(|m| (m.version, m))
106 .collect();
107
108 let mut new_migrations = Vec::new();
109 for migration in migrations {
110 match applied_migrations.get(&migration.version) {
111 Some(applied_migration) => {
112 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
113 {
114 Err(anyhow!(
115 "checksum mismatch for applied migration {}",
116 migration.description
117 ))?;
118 }
119 }
120 None => {
121 let elapsed = connection.apply(&migration).await?;
122 new_migrations.push((migration, elapsed));
123 }
124 }
125 }
126
127 Ok(new_migrations)
128 }
129
130 /// Initializes static data that resides in the database by upserting it.
131 pub async fn initialize_static_data(&mut self) -> Result<()> {
132 self.initialize_notification_kinds().await?;
133 Ok(())
134 }
135
136 /// Transaction runs things in a transaction. If you want to call other methods
137 /// and pass the transaction around you need to reborrow the transaction at each
138 /// call site with: `&*tx`.
139 pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
140 where
141 F: Send + Fn(TransactionHandle) -> Fut,
142 Fut: Send + Future<Output = Result<T>>,
143 {
144 let body = async {
145 let mut i = 0;
146 loop {
147 let (tx, result) = self.with_transaction(&f).await?;
148 match result {
149 Ok(result) => match tx.commit().await.map_err(Into::into) {
150 Ok(()) => return Ok(result),
151 Err(error) => {
152 if !self.retry_on_serialization_error(&error, i).await {
153 return Err(error);
154 }
155 }
156 },
157 Err(error) => {
158 tx.rollback().await?;
159 if !self.retry_on_serialization_error(&error, i).await {
160 return Err(error);
161 }
162 }
163 }
164 i += 1;
165 }
166 };
167
168 self.run(body).await
169 }
170
171 /// The same as room_transaction, but if you need to only optionally return a Room.
172 async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
173 where
174 F: Send + Fn(TransactionHandle) -> Fut,
175 Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
176 {
177 let body = async {
178 let mut i = 0;
179 loop {
180 let (tx, result) = self.with_transaction(&f).await?;
181 match result {
182 Ok(Some((room_id, data))) => {
183 let lock = self.rooms.entry(room_id).or_default().clone();
184 let _guard = lock.lock_owned().await;
185 match tx.commit().await.map_err(Into::into) {
186 Ok(()) => {
187 return Ok(Some(RoomGuard {
188 data,
189 _guard,
190 _not_send: PhantomData,
191 }));
192 }
193 Err(error) => {
194 if !self.retry_on_serialization_error(&error, i).await {
195 return Err(error);
196 }
197 }
198 }
199 }
200 Ok(None) => match tx.commit().await.map_err(Into::into) {
201 Ok(()) => return Ok(None),
202 Err(error) => {
203 if !self.retry_on_serialization_error(&error, i).await {
204 return Err(error);
205 }
206 }
207 },
208 Err(error) => {
209 tx.rollback().await?;
210 if !self.retry_on_serialization_error(&error, i).await {
211 return Err(error);
212 }
213 }
214 }
215 i += 1;
216 }
217 };
218
219 self.run(body).await
220 }
221
222 /// room_transaction runs the block in a transaction. It returns a RoomGuard, that keeps
223 /// the database locked until it is dropped. This ensures that updates sent to clients are
224 /// properly serialized with respect to database changes.
225 async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
226 where
227 F: Send + Fn(TransactionHandle) -> Fut,
228 Fut: Send + Future<Output = Result<T>>,
229 {
230 let body = async {
231 let mut i = 0;
232 loop {
233 let lock = self.rooms.entry(room_id).or_default().clone();
234 let _guard = lock.lock_owned().await;
235 let (tx, result) = self.with_transaction(&f).await?;
236 match result {
237 Ok(data) => match tx.commit().await.map_err(Into::into) {
238 Ok(()) => {
239 return Ok(RoomGuard {
240 data,
241 _guard,
242 _not_send: PhantomData,
243 });
244 }
245 Err(error) => {
246 if !self.retry_on_serialization_error(&error, i).await {
247 return Err(error);
248 }
249 }
250 },
251 Err(error) => {
252 tx.rollback().await?;
253 if !self.retry_on_serialization_error(&error, i).await {
254 return Err(error);
255 }
256 }
257 }
258 i += 1;
259 }
260 };
261
262 self.run(body).await
263 }
264
265 async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
266 where
267 F: Send + Fn(TransactionHandle) -> Fut,
268 Fut: Send + Future<Output = Result<T>>,
269 {
270 let tx = self
271 .pool
272 .begin_with_config(Some(IsolationLevel::Serializable), None)
273 .await?;
274
275 let mut tx = Arc::new(Some(tx));
276 let result = f(TransactionHandle(tx.clone())).await;
277 let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
278 return Err(anyhow!(
279 "couldn't complete transaction because it's still in use"
280 ))?;
281 };
282
283 Ok((tx, result))
284 }
285
286 async fn run<F, T>(&self, future: F) -> Result<T>
287 where
288 F: Future<Output = Result<T>>,
289 {
290 #[cfg(test)]
291 {
292 if let Executor::Deterministic(executor) = &self.executor {
293 executor.simulate_random_delay().await;
294 }
295
296 self.runtime.as_ref().unwrap().block_on(future)
297 }
298
299 #[cfg(not(test))]
300 {
301 future.await
302 }
303 }
304
305 async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
306 // If the error is due to a failure to serialize concurrent transactions, then retry
307 // this transaction after a delay. With each subsequent retry, double the delay duration.
308 // Also vary the delay randomly in order to ensure different database connections retry
309 // at different times.
310 if is_serialization_error(error) {
311 let base_delay = 4_u64 << prev_attempt_count.min(16);
312 let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
313 log::info!(
314 "retrying transaction after serialization error. delay: {} ms.",
315 randomized_delay
316 );
317 self.executor
318 .sleep(Duration::from_millis(randomized_delay as u64))
319 .await;
320 true
321 } else {
322 false
323 }
324 }
325}
326
327fn is_serialization_error(error: &Error) -> bool {
328 const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
329 match error {
330 Error::Database(
331 DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
332 | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
333 ) if error
334 .as_database_error()
335 .and_then(|error| error.code())
336 .as_deref()
337 == Some(SERIALIZATION_FAILURE_CODE) =>
338 {
339 true
340 }
341 _ => false,
342 }
343}
344
345/// A handle to a [`DatabaseTransaction`].
346pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
347
348impl Deref for TransactionHandle {
349 type Target = DatabaseTransaction;
350
351 fn deref(&self) -> &Self::Target {
352 self.0.as_ref().as_ref().unwrap()
353 }
354}
355
356/// [`RoomGuard`] keeps a database transaction alive until it is dropped.
357/// so that updates to rooms are serialized.
358pub struct RoomGuard<T> {
359 data: T,
360 _guard: OwnedMutexGuard<()>,
361 _not_send: PhantomData<Rc<()>>,
362}
363
364impl<T> Deref for RoomGuard<T> {
365 type Target = T;
366
367 fn deref(&self) -> &T {
368 &self.data
369 }
370}
371
372impl<T> DerefMut for RoomGuard<T> {
373 fn deref_mut(&mut self) -> &mut T {
374 &mut self.data
375 }
376}
377
378impl<T> RoomGuard<T> {
379 /// Returns the inner value of the guard.
380 pub fn into_inner(self) -> T {
381 self.data
382 }
383}
384
385#[derive(Clone, Debug, PartialEq, Eq)]
386pub enum Contact {
387 Accepted { user_id: UserId, busy: bool },
388 Outgoing { user_id: UserId },
389 Incoming { user_id: UserId },
390}
391
392impl Contact {
393 pub fn user_id(&self) -> UserId {
394 match self {
395 Contact::Accepted { user_id, .. } => *user_id,
396 Contact::Outgoing { user_id } => *user_id,
397 Contact::Incoming { user_id, .. } => *user_id,
398 }
399 }
400}
401
/// A list of notifications, each paired with a `UserId` (presumably the
/// recipient — confirm at the call sites).
pub type NotificationBatch = Vec<(UserId, proto::Notification)>;

/// The result of creating a channel message.
pub struct CreatedChannelMessage {
    pub message_id: MessageId,
    pub participant_connection_ids: Vec<ConnectionId>,
    pub channel_members: Vec<UserId>,
    pub notifications: NotificationBatch,
}

/// An email address together with its confirmation code.
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}

/// The details of a new signup, as deserialized from incoming data.
#[derive(Clone, Debug, Deserialize)]
pub struct NewSignup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
    pub created_at: Option<DateTime>,
}

/// Aggregate waitlist counts: a total plus one count per platform (and an
/// `unknown_count` bucket).
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)]
pub struct WaitlistSummary {
    pub count: i64,
    pub linux_count: i64,
    pub mac_count: i64,
    pub windows_count: i64,
    pub unknown_count: i64,
}
438
/// The parameters to create a new user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
}

/// The result of creating a new user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}

/// The result of moving a channel.
#[derive(Debug)]
pub struct MoveChannelResult {
    pub participants_to_update: HashMap<UserId, ChannelsForUser>,
    pub participants_to_remove: HashSet<UserId>,
    pub moved_channels: HashSet<ChannelId>,
}

/// The result of renaming a channel.
#[derive(Debug)]
pub struct RenameChannelResult {
    pub channel: Channel,
    pub participants_to_update: HashMap<UserId, Channel>,
}

/// The result of creating a channel.
#[derive(Debug)]
pub struct CreateChannelResult {
    pub channel: Channel,
    pub participants_to_update: Vec<(UserId, ChannelsForUser)>,
}

/// The result of setting a channel's visibility.
#[derive(Debug)]
pub struct SetChannelVisibilityResult {
    pub participants_to_update: HashMap<UserId, ChannelsForUser>,
    pub participants_to_remove: HashSet<UserId>,
    pub channels_to_remove: Vec<ChannelId>,
}

/// The result of updating a channel membership.
#[derive(Debug)]
pub struct MembershipUpdated {
    pub channel_id: ChannelId,
    pub new_channels: ChannelsForUser,
    pub removed_channels: Vec<ChannelId>,
}

/// The result of setting a member's role.
#[derive(Debug)]
pub enum SetMemberRoleResult {
    /// Carries the updated [`Channel`] (presumably when the target had only a
    /// pending invite — confirm in the channel queries).
    InviteUpdated(Channel),
    /// Carries the resulting membership change.
    MembershipUpdated(MembershipUpdated),
}

/// The result of inviting a member to a channel.
#[derive(Debug)]
pub struct InviteMemberResult {
    pub channel: Channel,
    pub notifications: NotificationBatch,
}

/// The result of responding to a channel invite.
#[derive(Debug)]
pub struct RespondToChannelInvite {
    /// The membership change, if the response produced one.
    pub membership_update: Option<MembershipUpdated>,
    pub notifications: NotificationBatch,
}

/// The result of removing a member from a channel.
#[derive(Debug)]
pub struct RemoveChannelMemberResult {
    pub membership_update: MembershipUpdated,
    /// The id of a related notification, if there is one.
    pub notification_id: Option<NotificationId>,
}
518
/// A channel, as presented to a particular user.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    pub visibility: ChannelVisibility,
    /// The role supplied when this value was built (see `Channel::from_model`).
    pub role: ChannelRole,
    /// parent_path is the channel ids from the root to this one (not including this one)
    pub parent_path: Vec<ChannelId>,
}
528
529impl Channel {
530 fn from_model(value: channel::Model, role: ChannelRole) -> Self {
531 Channel {
532 id: value.id,
533 visibility: value.visibility,
534 name: value.clone().name,
535 role,
536 parent_path: value.ancestors().collect(),
537 }
538 }
539
540 pub fn to_proto(&self) -> proto::Channel {
541 proto::Channel {
542 id: self.id.to_proto(),
543 name: self.name.clone(),
544 visibility: self.visibility.into(),
545 role: self.role.into(),
546 parent_path: self.parent_path.iter().map(|c| c.to_proto()).collect(),
547 }
548 }
549}
550
/// A user's membership in a channel: the user, their role, and the membership
/// kind as defined by the protocol.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ChannelMember {
    pub role: ChannelRole,
    pub user_id: UserId,
    pub kind: proto::channel_member::Kind,
}
557
558impl ChannelMember {
559 pub fn to_proto(&self) -> proto::ChannelMember {
560 proto::ChannelMember {
561 role: self.role.into(),
562 user_id: self.user_id.to_proto(),
563 kind: self.kind.into(),
564 }
565 }
566}
567
/// The channel state assembled for a single user.
#[derive(Debug, PartialEq)]
pub struct ChannelsForUser {
    pub channels: Vec<Channel>,
    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
    pub unseen_buffer_changes: Vec<proto::UnseenChannelBufferChange>,
    pub channel_messages: Vec<proto::UnseenChannelMessage>,
}

/// A channel buffer that was rejoined.
#[derive(Debug)]
pub struct RejoinedChannelBuffer {
    pub buffer: proto::RejoinedChannelBuffer,
    /// Presumably the connection id held before rejoining — confirm in the buffer queries.
    pub old_connection_id: ConnectionId,
}

/// The result of joining a room.
#[derive(Clone)]
pub struct JoinRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}

/// The result of rejoining a room, including the projects that were rejoined
/// or reshared.
pub struct RejoinedRoom {
    pub room: proto::Room,
    pub rejoined_projects: Vec<RejoinedProject>,
    pub reshared_projects: Vec<ResharedProject>,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}

/// A project that was reshared while rejoining a room.
pub struct ResharedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<proto::WorktreeMetadata>,
}

/// A project that was rejoined while rejoining a room.
pub struct RejoinedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<RejoinedWorktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}

/// A worktree of a rejoined project, with the entry, repository, diagnostic and
/// settings data being sent for it.
#[derive(Debug)]
pub struct RejoinedWorktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub updated_entries: Vec<proto::Entry>,
    pub removed_entries: Vec<u64>,
    pub updated_repositories: Vec<proto::RepositoryEntry>,
    pub removed_repositories: Vec<u64>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}

/// The result of leaving a room.
pub struct LeftRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
    pub deleted: bool,
}

/// The result of refreshing a room.
pub struct RefreshedRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub stale_participant_user_ids: Vec<UserId>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}

/// The result of refreshing a channel buffer.
pub struct RefreshedChannelBuffer {
    pub connection_ids: Vec<ConnectionId>,
    pub collaborators: Vec<proto::Collaborator>,
}

/// A project's collaborators, worktrees (keyed by worktree id) and language servers.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}

/// A connection participating in a project.
pub struct ProjectCollaborator {
    pub connection_id: ConnectionId,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    pub is_host: bool,
}
662
663impl ProjectCollaborator {
664 pub fn to_proto(&self) -> proto::Collaborator {
665 proto::Collaborator {
666 peer_id: Some(self.connection_id.into()),
667 replica_id: self.replica_id.0 as u32,
668 user_id: self.user_id.to_proto(),
669 }
670 }
671}
672
/// A project affected by a user leaving a room.
#[derive(Debug)]
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    pub connection_ids: Vec<ConnectionId>,
}

/// A project worktree with its entries, repositories (keyed by id), diagnostic
/// summaries and settings files.
pub struct Worktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}

/// A settings file in a worktree: its path and contents.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    pub path: String,
    pub content: String,
}