diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index e667930cad2953d1379f9aa07389202f16ff2219..3066260bc431f65f18a68bbc7bd68442c18e0078 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -2133,21 +2133,7 @@ impl Database { { let body = async { loop { - let tx = self.pool.begin().await?; - - // In Postgres, serializable transactions are opt-in - if let DatabaseBackend::Postgres = self.pool.get_database_backend() { - tx.execute(Statement::from_string( - DatabaseBackend::Postgres, - "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), - )) - .await?; - } - - let mut tx = Arc::new(Some(tx)); - let result = f(TransactionHandle(tx.clone())).await; - let tx = Arc::get_mut(&mut tx).unwrap().take().unwrap(); - + let (tx, result) = self.with_transaction(&f).await?; match result { Ok(result) => { tx.commit().await?; @@ -2196,21 +2182,7 @@ impl Database { { let body = async { loop { - let tx = self.pool.begin().await?; - - // In Postgres, serializable transactions are opt-in - if let DatabaseBackend::Postgres = self.pool.get_database_backend() { - tx.execute(Statement::from_string( - DatabaseBackend::Postgres, - "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), - )) - .await?; - } - - let mut tx = Arc::new(Some(tx)); - let result = f(TransactionHandle(tx.clone())).await; - let tx = Arc::get_mut(&mut tx).unwrap().take().unwrap(); - + let (tx, result) = self.with_transaction(&f).await?; match result { Ok((room_id, data)) => { let lock = self.rooms.entry(room_id).or_default().clone(); @@ -2257,6 +2229,31 @@ impl Database { body.await } } + + async fn with_transaction(&self, f: &F) -> Result<(DatabaseTransaction, Result)> + where + F: Send + Fn(TransactionHandle) -> Fut, + Fut: Send + Future>, + { + let tx = self.pool.begin().await?; + + // In Postgres, serializable transactions are opt-in + if let DatabaseBackend::Postgres = self.pool.get_database_backend() { + tx.execute(Statement::from_string( + DatabaseBackend::Postgres, + "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), + )) + .await?; + } + + let mut tx = Arc::new(Some(tx)); + let result = f(TransactionHandle(tx.clone())).await; + let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else { + return Err(anyhow!("couldn't complete transaction because it's still in use"))?; + }; + + Ok((tx, result)) + } } struct TransactionHandle(Arc>);