WIP: Manually rollback transactions to avoid spurious savepoint failure

Antonio Scandurra and Max Brunsfeld created

TODO:
- Avoid unwrapping transaction after f(tx)
- Remove duplication between `transaction` and `room_transaction`
- Introduce random delay before and after committing a transaction
- Run lots of randomized tests
- Investigate diverging diagnostic summaries

Co-Authored-By: Max Brunsfeld <max@zed.dev>

Change summary

crates/collab/src/db.rs  | 296 ++++++++++++++++++-----------------------
crates/collab/src/rpc.rs |   4 
2 files changed, 129 insertions(+), 171 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -106,10 +106,10 @@ impl Database {
     }
 
     pub async fn clear_stale_data(&self) -> Result<()> {
-        self.transact(|tx| async {
+        self.transaction(|tx| async move {
             project_collaborator::Entity::delete_many()
                 .filter(project_collaborator::Column::ConnectionEpoch.ne(self.epoch))
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
             room_participant::Entity::delete_many()
                 .filter(
@@ -117,11 +117,11 @@ impl Database {
                         .ne(self.epoch)
                         .or(room_participant::Column::CallingConnectionEpoch.ne(self.epoch)),
                 )
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
             project::Entity::delete_many()
                 .filter(project::Column::HostConnectionEpoch.ne(self.epoch))
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
             room::Entity::delete_many()
                 .filter(
@@ -133,9 +133,8 @@ impl Database {
                             .to_owned(),
                     ),
                 )
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
-            tx.commit().await?;
             Ok(())
         })
         .await
@@ -149,7 +148,8 @@ impl Database {
         admin: bool,
         params: NewUserParams,
     ) -> Result<NewUserResult> {
-        self.transact(|tx| async {
+        self.transaction(|tx| async {
+            let tx = tx;
             let user = user::Entity::insert(user::ActiveModel {
                 email_address: ActiveValue::set(Some(email_address.into())),
                 github_login: ActiveValue::set(params.github_login.clone()),
@@ -163,11 +163,9 @@ impl Database {
                     .update_column(user::Column::GithubLogin)
                     .to_owned(),
             )
-            .exec_with_returning(&tx)
+            .exec_with_returning(&*tx)
             .await?;
 
-            tx.commit().await?;
-
             Ok(NewUserResult {
                 user_id: user.id,
                 metrics_id: user.metrics_id.to_string(),
@@ -179,16 +177,16 @@ impl Database {
     }
 
     pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<user::Model>> {
-        self.transact(|tx| async move { Ok(user::Entity::find_by_id(id).one(&tx).await?) })
+        self.transaction(|tx| async move { Ok(user::Entity::find_by_id(id).one(&*tx).await?) })
             .await
     }
 
     pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
-        self.transact(|tx| async {
+        self.transaction(|tx| async {
             let tx = tx;
             Ok(user::Entity::find()
                 .filter(user::Column::Id.is_in(ids.iter().copied()))
-                .all(&tx)
+                .all(&*tx)
                 .await?)
         })
         .await
@@ -199,32 +197,32 @@ impl Database {
         github_login: &str,
         github_user_id: Option<i32>,
     ) -> Result<Option<User>> {
-        self.transact(|tx| async {
-            let tx = tx;
+        self.transaction(|tx| async move {
+            let tx = &*tx;
             if let Some(github_user_id) = github_user_id {
                 if let Some(user_by_github_user_id) = user::Entity::find()
                     .filter(user::Column::GithubUserId.eq(github_user_id))
-                    .one(&tx)
+                    .one(tx)
                     .await?
                 {
                     let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
                     user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
-                    Ok(Some(user_by_github_user_id.update(&tx).await?))
+                    Ok(Some(user_by_github_user_id.update(tx).await?))
                 } else if let Some(user_by_github_login) = user::Entity::find()
                     .filter(user::Column::GithubLogin.eq(github_login))
-                    .one(&tx)
+                    .one(tx)
                     .await?
                 {
                     let mut user_by_github_login = user_by_github_login.into_active_model();
                     user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
-                    Ok(Some(user_by_github_login.update(&tx).await?))
+                    Ok(Some(user_by_github_login.update(tx).await?))
                 } else {
                     Ok(None)
                 }
             } else {
                 Ok(user::Entity::find()
                     .filter(user::Column::GithubLogin.eq(github_login))
-                    .one(&tx)
+                    .one(tx)
                     .await?)
             }
         })
@@ -232,12 +230,12 @@ impl Database {
     }
 
     pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             Ok(user::Entity::find()
                 .order_by_asc(user::Column::GithubLogin)
                 .limit(limit as u64)
                 .offset(page as u64 * limit as u64)
-                .all(&tx)
+                .all(&*tx)
                 .await?)
         })
         .await
@@ -247,7 +245,7 @@ impl Database {
         &self,
         invited_by_another_user: bool,
     ) -> Result<Vec<User>> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             Ok(user::Entity::find()
                 .filter(
                     user::Column::InviteCount
@@ -258,7 +256,7 @@ impl Database {
                             user::Column::InviterId.is_null()
                         }),
                 )
-                .all(&tx)
+                .all(&*tx)
                 .await?)
         })
         .await
@@ -270,12 +268,12 @@ impl Database {
             MetricsId,
         }
 
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             let metrics_id: Uuid = user::Entity::find_by_id(id)
                 .select_only()
                 .column(user::Column::MetricsId)
                 .into_values::<_, QueryAs>()
-                .one(&tx)
+                .one(&*tx)
                 .await?
                 .ok_or_else(|| anyhow!("could not find user"))?;
             Ok(metrics_id.to_string())
@@ -284,45 +282,42 @@ impl Database {
     }
 
     pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             user::Entity::update_many()
                 .filter(user::Column::Id.eq(id))
                 .set(user::ActiveModel {
                     admin: ActiveValue::set(is_admin),
                     ..Default::default()
                 })
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
-            tx.commit().await?;
             Ok(())
         })
         .await
     }
 
     pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             user::Entity::update_many()
                 .filter(user::Column::Id.eq(id))
                 .set(user::ActiveModel {
                     connected_once: ActiveValue::set(connected_once),
                     ..Default::default()
                 })
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
-            tx.commit().await?;
             Ok(())
         })
         .await
     }
 
     pub async fn destroy_user(&self, id: UserId) -> Result<()> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             access_token::Entity::delete_many()
                 .filter(access_token::Column::UserId.eq(id))
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
-            user::Entity::delete_by_id(id).exec(&tx).await?;
-            tx.commit().await?;
+            user::Entity::delete_by_id(id).exec(&*tx).await?;
             Ok(())
         })
         .await
@@ -342,7 +337,7 @@ impl Database {
             user_b_busy: bool,
         }
 
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             let user_a_participant = Alias::new("user_a_participant");
             let user_b_participant = Alias::new("user_b_participant");
             let mut db_contacts = contact::Entity::find()
@@ -372,7 +367,7 @@ impl Database {
                     user_b_participant,
                 )
                 .into_model::<ContactWithUserBusyStatuses>()
-                .stream(&tx)
+                .stream(&*tx)
                 .await?;
 
             let mut contacts = Vec::new();
@@ -421,10 +416,10 @@ impl Database {
     }
 
     pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             let participant = room_participant::Entity::find()
                 .filter(room_participant::Column::UserId.eq(user_id))
-                .one(&tx)
+                .one(&*tx)
                 .await?;
             Ok(participant.is_some())
         })
@@ -432,7 +427,7 @@ impl Database {
     }
 
     pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             let (id_a, id_b) = if user_id_1 < user_id_2 {
                 (user_id_1, user_id_2)
             } else {
@@ -446,7 +441,7 @@ impl Database {
                         .and(contact::Column::UserIdB.eq(id_b))
                         .and(contact::Column::Accepted.eq(true)),
                 )
-                .one(&tx)
+                .one(&*tx)
                 .await?
                 .is_some())
         })
@@ -454,7 +449,7 @@ impl Database {
     }
 
     pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
                 (sender_id, receiver_id, true)
             } else {
@@ -487,11 +482,10 @@ impl Database {
                     )
                     .to_owned(),
             )
-            .exec_without_returning(&tx)
+            .exec_without_returning(&*tx)
             .await?;
 
             if rows_affected == 1 {
-                tx.commit().await?;
                 Ok(())
             } else {
                 Err(anyhow!("contact already requested"))?
@@ -501,7 +495,7 @@ impl Database {
     }
 
     pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             let (id_a, id_b) = if responder_id < requester_id {
                 (responder_id, requester_id)
             } else {
@@ -514,11 +508,10 @@ impl Database {
                         .eq(id_a)
                         .and(contact::Column::UserIdB.eq(id_b)),
                 )
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
 
             if result.rows_affected == 1 {
-                tx.commit().await?;
                 Ok(())
             } else {
                 Err(anyhow!("no such contact"))?
@@ -532,7 +525,7 @@ impl Database {
         user_id: UserId,
         contact_user_id: UserId,
     ) -> Result<()> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
                 (user_id, contact_user_id, true)
             } else {
@@ -557,12 +550,11 @@ impl Database {
                                     .and(contact::Column::Accepted.eq(false))),
                         ),
                 )
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
             if result.rows_affected == 0 {
                 Err(anyhow!("no such contact request"))?
             } else {
-                tx.commit().await?;
                 Ok(())
             }
         })
@@ -575,7 +567,7 @@ impl Database {
         requester_id: UserId,
         accept: bool,
     ) -> Result<()> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             let (id_a, id_b, a_to_b) = if responder_id < requester_id {
                 (responder_id, requester_id, false)
             } else {
@@ -594,7 +586,7 @@ impl Database {
                             .and(contact::Column::UserIdB.eq(id_b))
                             .and(contact::Column::AToB.eq(a_to_b)),
                     )
-                    .exec(&tx)
+                    .exec(&*tx)
                     .await?;
                 result.rows_affected
             } else {
@@ -606,14 +598,13 @@ impl Database {
                             .and(contact::Column::AToB.eq(a_to_b))
                             .and(contact::Column::Accepted.eq(false)),
                     )
-                    .exec(&tx)
+                    .exec(&*tx)
                     .await?;
 
                 result.rows_affected
             };
 
             if rows_affected == 1 {
-                tx.commit().await?;
                 Ok(())
             } else {
                 Err(anyhow!("no such contact request"))?
@@ -635,7 +626,7 @@ impl Database {
     }
 
     pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
-        self.transact(|tx| async {
+        self.transaction(|tx| async {
             let tx = tx;
             let like_string = Self::fuzzy_like_string(name_query);
             let query = "
@@ -652,7 +643,7 @@ impl Database {
                     query.into(),
                     vec![like_string.into(), name_query.into(), limit.into()],
                 ))
-                .all(&tx)
+                .all(&*tx)
                 .await?)
         })
         .await
@@ -661,7 +652,7 @@ impl Database {
     // signups
 
     pub async fn create_signup(&self, signup: &NewSignup) -> Result<()> {
-        self.transact(|tx| async {
+        self.transaction(|tx| async move {
             signup::Entity::insert(signup::ActiveModel {
                 email_address: ActiveValue::set(signup.email_address.clone()),
                 email_confirmation_code: ActiveValue::set(random_email_confirmation_code()),
@@ -681,16 +672,15 @@ impl Database {
                     .update_column(signup::Column::EmailAddress)
                     .to_owned(),
             )
-            .exec(&tx)
+            .exec(&*tx)
             .await?;
-            tx.commit().await?;
             Ok(())
         })
         .await
     }
 
     pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             let query = "
                 SELECT
                     COUNT(*) as count,
@@ -711,7 +701,7 @@ impl Database {
                     query.into(),
                     vec![],
                 ))
-                .one(&tx)
+                .one(&*tx)
                 .await?
                 .ok_or_else(|| anyhow!("invalid result"))?,
             )
@@ -724,23 +714,23 @@ impl Database {
             .iter()
             .map(|s| s.email_address.as_str())
             .collect::<Vec<_>>();
-        self.transact(|tx| async {
+        self.transaction(|tx| async {
+            let tx = tx;
             signup::Entity::update_many()
                 .filter(signup::Column::EmailAddress.is_in(emails.iter().copied()))
                 .set(signup::ActiveModel {
                     email_confirmation_sent: ActiveValue::set(true),
                     ..Default::default()
                 })
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
-            tx.commit().await?;
             Ok(())
         })
         .await
     }
 
     pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             Ok(signup::Entity::find()
                 .select_only()
                 .column(signup::Column::EmailAddress)
@@ -755,7 +745,7 @@ impl Database {
                 .order_by_asc(signup::Column::CreatedAt)
                 .limit(count as u64)
                 .into_model()
-                .all(&tx)
+                .all(&*tx)
                 .await?)
         })
         .await
@@ -769,10 +759,10 @@ impl Database {
         email_address: &str,
         device_id: Option<&str>,
     ) -> Result<Invite> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             let existing_user = user::Entity::find()
                 .filter(user::Column::EmailAddress.eq(email_address))
-                .one(&tx)
+                .one(&*tx)
                 .await?;
 
             if existing_user.is_some() {
@@ -785,7 +775,7 @@ impl Database {
                         .eq(code)
                         .and(user::Column::InviteCount.gt(0)),
                 )
-                .one(&tx)
+                .one(&*tx)
                 .await?
             {
                 Some(inviting_user) => inviting_user,
@@ -806,7 +796,7 @@ impl Database {
                     user::Column::InviteCount,
                     Expr::col(user::Column::InviteCount).sub(1),
                 )
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
 
             let signup = signup::Entity::insert(signup::ActiveModel {
@@ -826,9 +816,8 @@ impl Database {
                     .update_column(signup::Column::InvitingUserId)
                     .to_owned(),
             )
-            .exec_with_returning(&tx)
+            .exec_with_returning(&*tx)
             .await?;
-            tx.commit().await?;
 
             Ok(Invite {
                 email_address: signup.email_address,
@@ -843,7 +832,7 @@ impl Database {
         invite: &Invite,
         user: NewUserParams,
     ) -> Result<Option<NewUserResult>> {
-        self.transact(|tx| async {
+        self.transaction(|tx| async {
             let tx = tx;
             let signup = signup::Entity::find()
                 .filter(
@@ -854,7 +843,7 @@ impl Database {
                                 .eq(invite.email_confirmation_code.as_str()),
                         ),
                 )
-                .one(&tx)
+                .one(&*tx)
                 .await?
                 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 
@@ -881,12 +870,12 @@ impl Database {
                     ])
                     .to_owned(),
             )
-            .exec_with_returning(&tx)
+            .exec_with_returning(&*tx)
             .await?;
 
             let mut signup = signup.into_active_model();
             signup.user_id = ActiveValue::set(Some(user.id));
-            let signup = signup.update(&tx).await?;
+            let signup = signup.update(&*tx).await?;
 
             if let Some(inviting_user_id) = signup.inviting_user_id {
                 contact::Entity::insert(contact::ActiveModel {
@@ -898,11 +887,10 @@ impl Database {
                     ..Default::default()
                 })
                 .on_conflict(OnConflict::new().do_nothing().to_owned())
-                .exec_without_returning(&tx)
+                .exec_without_returning(&*tx)
                 .await?;
             }
 
-            tx.commit().await?;
             Ok(Some(NewUserResult {
                 user_id: user.id,
                 metrics_id: user.metrics_id.to_string(),
@@ -914,7 +902,7 @@ impl Database {
     }
 
     pub async fn set_invite_count_for_user(&self, id: UserId, count: i32) -> Result<()> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             if count > 0 {
                 user::Entity::update_many()
                     .filter(
@@ -926,7 +914,7 @@ impl Database {
                         invite_code: ActiveValue::set(Some(random_invite_code())),
                         ..Default::default()
                     })
-                    .exec(&tx)
+                    .exec(&*tx)
                     .await?;
             }
 
@@ -936,17 +924,16 @@ impl Database {
                     invite_count: ActiveValue::set(count),
                     ..Default::default()
                 })
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
-            tx.commit().await?;
             Ok(())
         })
         .await
     }
 
     pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, i32)>> {
-        self.transact(|tx| async move {
-            match user::Entity::find_by_id(id).one(&tx).await? {
+        self.transaction(|tx| async move {
+            match user::Entity::find_by_id(id).one(&*tx).await? {
                 Some(user) if user.invite_code.is_some() => {
                     Ok(Some((user.invite_code.unwrap(), user.invite_count)))
                 }
@@ -957,10 +944,10 @@ impl Database {
     }
 
     pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             user::Entity::find()
                 .filter(user::Column::InviteCode.eq(code))
-                .one(&tx)
+                .one(&*tx)
                 .await?
                 .ok_or_else(|| {
                     Error::Http(
@@ -978,14 +965,14 @@ impl Database {
         &self,
         user_id: UserId,
     ) -> Result<Option<proto::IncomingCall>> {
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             let pending_participant = room_participant::Entity::find()
                 .filter(
                     room_participant::Column::UserId
                         .eq(user_id)
                         .and(room_participant::Column::AnsweringConnectionId.is_null()),
                 )
-                .one(&tx)
+                .one(&*tx)
                 .await?;
 
             if let Some(pending_participant) = pending_participant {
@@ -1004,12 +991,12 @@ impl Database {
         connection_id: ConnectionId,
         live_kit_room: &str,
     ) -> Result<RoomGuard<proto::Room>> {
-        self.transact(|tx| async move {
+        self.room_transaction(|tx| async move {
             let room = room::ActiveModel {
                 live_kit_room: ActiveValue::set(live_kit_room.into()),
                 ..Default::default()
             }
-            .insert(&tx)
+            .insert(&*tx)
             .await?;
             let room_id = room.id;
 
@@ -1023,11 +1010,11 @@ impl Database {
                 calling_connection_epoch: ActiveValue::set(self.epoch),
                 ..Default::default()
             }
-            .insert(&tx)
+            .insert(&*tx)
             .await?;
 
             let room = self.get_room(room_id, &tx).await?;
-            self.commit_room_transaction(room_id, tx, room).await
+            Ok((room_id, room))
         })
         .await
     }
@@ -1040,7 +1027,7 @@ impl Database {
         called_user_id: UserId,
         initial_project_id: Option<ProjectId>,
     ) -> Result<RoomGuard<(proto::Room, proto::IncomingCall)>> {
-        self.transact(|tx| async move {
+        self.room_transaction(|tx| async move {
             room_participant::ActiveModel {
                 room_id: ActiveValue::set(room_id),
                 user_id: ActiveValue::set(called_user_id),
@@ -1050,14 +1037,13 @@ impl Database {
                 initial_project_id: ActiveValue::set(initial_project_id),
                 ..Default::default()
             }
-            .insert(&tx)
+            .insert(&*tx)
             .await?;
 
             let room = self.get_room(room_id, &tx).await?;
             let incoming_call = Self::build_incoming_call(&room, called_user_id)
                 .ok_or_else(|| anyhow!("failed to build incoming call"))?;
-            self.commit_room_transaction(room_id, tx, (room, incoming_call))
-                .await
+            Ok((room_id, (room, incoming_call)))
         })
         .await
     }
@@ -1067,17 +1053,17 @@ impl Database {
         room_id: RoomId,
         called_user_id: UserId,
     ) -> Result<RoomGuard<proto::Room>> {
-        self.transact(|tx| async move {
+        self.room_transaction(|tx| async move {
             room_participant::Entity::delete_many()
                 .filter(
                     room_participant::Column::RoomId
                         .eq(room_id)
                         .and(room_participant::Column::UserId.eq(called_user_id)),
                 )
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
             let room = self.get_room(room_id, &tx).await?;
-            self.commit_room_transaction(room_id, tx, room).await
+            Ok((room_id, room))
         })
         .await
     }
@@ -1087,14 +1073,14 @@ impl Database {
         expected_room_id: Option<RoomId>,
         user_id: UserId,
     ) -> Result<RoomGuard<proto::Room>> {
-        self.transact(|tx| async move {
+        self.room_transaction(|tx| async move {
             let participant = room_participant::Entity::find()
                 .filter(
                     room_participant::Column::UserId
                         .eq(user_id)
                         .and(room_participant::Column::AnsweringConnectionId.is_null()),
                 )
-                .one(&tx)
+                .one(&*tx)
                 .await?
                 .ok_or_else(|| anyhow!("could not decline call"))?;
             let room_id = participant.room_id;
@@ -1104,11 +1090,11 @@ impl Database {
             }
 
             room_participant::Entity::delete(participant.into_active_model())
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
 
             let room = self.get_room(room_id, &tx).await?;
-            self.commit_room_transaction(room_id, tx, room).await
+            Ok((room_id, room))
         })
         .await
     }
@@ -1119,7 +1105,7 @@ impl Database {
         calling_connection_id: ConnectionId,
         called_user_id: UserId,
     ) -> Result<RoomGuard<proto::Room>> {
-        self.transact(|tx| async move {
+        self.room_transaction(|tx| async move {
             let participant = room_participant::Entity::find()
                 .filter(
                     room_participant::Column::UserId
@@ -1130,7 +1116,7 @@ impl Database {
                         )
                         .and(room_participant::Column::AnsweringConnectionId.is_null()),
                 )
-                .one(&tx)
+                .one(&*tx)
                 .await?
                 .ok_or_else(|| anyhow!("could not cancel call"))?;
             let room_id = participant.room_id;
@@ -1139,11 +1125,11 @@ impl Database {
             }
 
             room_participant::Entity::delete(participant.into_active_model())
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
 
             let room = self.get_room(room_id, &tx).await?;
-            self.commit_room_transaction(room_id, tx, room).await
+            Ok((room_id, room))
         })
         .await
     }
@@ -1154,7 +1140,7 @@ impl Database {
         user_id: UserId,
         connection_id: ConnectionId,
     ) -> Result<RoomGuard<proto::Room>> {
-        self.transact(|tx| async move {
+        self.room_transaction(|tx| async move {
             let result = room_participant::Entity::update_many()
                 .filter(
                     room_participant::Column::RoomId
@@ -1167,33 +1153,30 @@ impl Database {
                     answering_connection_epoch: ActiveValue::set(Some(self.epoch)),
                     ..Default::default()
                 })
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
             if result.rows_affected == 0 {
                 Err(anyhow!("room does not exist or was already joined"))?
             } else {
                 let room = self.get_room(room_id, &tx).await?;
-                self.commit_room_transaction(room_id, tx, room).await
+                Ok((room_id, room))
             }
         })
         .await
     }
 
-    pub async fn leave_room(
-        &self,
-        connection_id: ConnectionId,
-    ) -> Result<Option<RoomGuard<LeftRoom>>> {
-        self.transact(|tx| async move {
+    pub async fn leave_room(&self, connection_id: ConnectionId) -> Result<RoomGuard<LeftRoom>> {
+        self.room_transaction(|tx| async move {
             let leaving_participant = room_participant::Entity::find()
                 .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
-                .one(&tx)
+                .one(&*tx)
                 .await?;
 
             if let Some(leaving_participant) = leaving_participant {
                 // Leave room.
                 let room_id = leaving_participant.room_id;
                 room_participant::Entity::delete_by_id(leaving_participant.id)
-                    .exec(&tx)
+                    .exec(&*tx)
                     .await?;
 
                 // Cancel pending calls initiated by the leaving user.
@@ -1203,14 +1186,14 @@ impl Database {
                             .eq(connection_id.0)
                             .and(room_participant::Column::AnsweringConnectionId.is_null()),
                     )
-                    .all(&tx)
+                    .all(&*tx)
                     .await?;
                 room_participant::Entity::delete_many()
                     .filter(
                         room_participant::Column::Id
                             .is_in(called_participants.iter().map(|participant| participant.id)),
                     )
-                    .exec(&tx)
+                    .exec(&*tx)
                     .await?;
                 let canceled_calls_to_user_ids = called_participants
                     .into_iter()
@@ -1230,12 +1213,12 @@ impl Database {
                     )
                     .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0))
                     .into_values::<_, QueryProjectIds>()
-                    .all(&tx)
+                    .all(&*tx)
                     .await?;
                 let mut left_projects = HashMap::default();
                 let mut collaborators = project_collaborator::Entity::find()
                     .filter(project_collaborator::Column::ProjectId.is_in(project_ids))
-                    .stream(&tx)
+                    .stream(&*tx)
                     .await?;
                 while let Some(collaborator) = collaborators.next().await {
                     let collaborator = collaborator?;
@@ -1266,7 +1249,7 @@ impl Database {
                 // Leave projects.
                 project_collaborator::Entity::delete_many()
                     .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0))
-                    .exec(&tx)
+                    .exec(&*tx)
                     .await?;
 
                 // Unshare projects.
@@ -1276,33 +1259,27 @@ impl Database {
                             .eq(room_id)
                             .and(project::Column::HostConnectionId.eq(connection_id.0)),
                     )
-                    .exec(&tx)
+                    .exec(&*tx)
                     .await?;
 
                 let room = self.get_room(room_id, &tx).await?;
                 if room.participants.is_empty() {
-                    room::Entity::delete_by_id(room_id).exec(&tx).await?;
+                    room::Entity::delete_by_id(room_id).exec(&*tx).await?;
                 }
 
-                let left_room = self
-                    .commit_room_transaction(
-                        room_id,
-                        tx,
-                        LeftRoom {
-                            room,
-                            left_projects,
-                            canceled_calls_to_user_ids,
-                        },
-                    )
-                    .await?;
+                let left_room = LeftRoom {
+                    room,
+                    left_projects,
+                    canceled_calls_to_user_ids,
+                };
 
                 if left_room.room.participants.is_empty() {
                     self.rooms.remove(&room_id);
                 }
 
-                Ok(Some(left_room))
+                Ok((room_id, left_room))
             } else {
-                Ok(None)
+                Err(anyhow!("could not leave room"))?
             }
         })
         .await
@@ -1314,8 +1291,8 @@ impl Database {
         connection_id: ConnectionId,
         location: proto::ParticipantLocation,
     ) -> Result<RoomGuard<proto::Room>> {
-        self.transact(|tx| async {
-            let mut tx = tx;
+        self.room_transaction(|tx| async {
+            let tx = tx;
             let location_kind;
             let location_project_id;
             match location
@@ -1348,12 +1325,12 @@ impl Database {
                     location_project_id: ActiveValue::set(location_project_id),
                     ..Default::default()
                 })
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
 
             if result.rows_affected == 1 {
-                let room = self.get_room(room_id, &mut tx).await?;
-                self.commit_room_transaction(room_id, tx, room).await
+                let room = self.get_room(room_id, &tx).await?;
+                Ok((room_id, room))
             } else {
                 Err(anyhow!("could not update room participant location"))?
             }
@@ -1478,22 +1455,6 @@ impl Database {
         })
     }
 
-    async fn commit_room_transaction<T>(
-        &self,
-        room_id: RoomId,
-        tx: DatabaseTransaction,
-        data: T,
-    ) -> Result<RoomGuard<T>> {
-        let lock = self.rooms.entry(room_id).or_default().clone();
-        let _guard = lock.lock_owned().await;
-        tx.commit().await?;
-        Ok(RoomGuard {
-            data,
-            _guard,
-            _not_send: PhantomData,
-        })
-    }
-
     // projects
 
     pub async fn project_count_excluding_admins(&self) -> Result<usize> {
@@ -1502,14 +1463,14 @@ impl Database {
             Count,
         }
 
-        self.transact(|tx| async move {
+        self.transaction(|tx| async move {
             Ok(project::Entity::find()
                 .select_only()
                 .column_as(project::Column::Id.count(), QueryAs::Count)
                 .inner_join(user::Entity)
                 .filter(user::Column::Admin.eq(false))
                 .into_values::<_, QueryAs>()
-                .one(&tx)
+                .one(&*tx)
                 .await?
                 .unwrap_or(0) as usize)
         })
@@ -1522,10 +1483,10 @@ impl Database {
         connection_id: ConnectionId,
         worktrees: &[proto::WorktreeMetadata],
     ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
-        self.transact(|tx| async move {
+        self.room_transaction(|tx| async move {
             let participant = room_participant::Entity::find()
                 .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
-                .one(&tx)
+                .one(&*tx)
                 .await?
                 .ok_or_else(|| anyhow!("could not find participant"))?;
             if participant.room_id != room_id {
@@ -1539,7 +1500,7 @@ impl Database {
                 host_connection_epoch: ActiveValue::set(self.epoch),
                 ..Default::default()
             }
-            .insert(&tx)
+            .insert(&*tx)
             .await?;
 
             if !worktrees.is_empty() {
@@ -1554,7 +1515,7 @@ impl Database {
                         is_complete: ActiveValue::set(false),
                     }
                 }))
-                .exec(&tx)
+                .exec(&*tx)
                 .await?;
             }
 
@@ -1567,12 +1528,11 @@ impl Database {
                 is_host: ActiveValue::set(true),
                 ..Default::default()
             }
-            .insert(&tx)
+            .insert(&*tx)
             .await?;
 
             let room = self.get_room(room_id, &tx).await?;
-            self.commit_room_transaction(room_id, tx, (project.id, room))
-                .await
+            Ok((room_id, (project.id, room)))
         })
         .await
     }

crates/collab/src/rpc.rs 🔗

@@ -1854,9 +1854,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> {
     let live_kit_room;
     let delete_live_kit_room;
     {
-        let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? else {
-            return Err(anyhow!("no room to leave"))?;
-        };
+        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
         contacts_to_update.insert(session.user_id);
 
         for project in left_room.left_projects.values() {