Remove duplication between `transaction` and `room_transaction`

Antonio Scandurra created

Change summary

crates/collab/src/db.rs | 57 ++++++++++++++++++++----------------------
1 file changed, 27 insertions(+), 30 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -2133,21 +2133,7 @@ impl Database {
     {
         let body = async {
             loop {
-                let tx = self.pool.begin().await?;
-
-                // In Postgres, serializable transactions are opt-in
-                if let DatabaseBackend::Postgres = self.pool.get_database_backend() {
-                    tx.execute(Statement::from_string(
-                        DatabaseBackend::Postgres,
-                        "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(),
-                    ))
-                    .await?;
-                }
-
-                let mut tx = Arc::new(Some(tx));
-                let result = f(TransactionHandle(tx.clone())).await;
-                let tx = Arc::get_mut(&mut tx).unwrap().take().unwrap();
-
+                let (tx, result) = self.with_transaction(&f).await?;
                 match result {
                     Ok(result) => {
                         tx.commit().await?;
@@ -2196,21 +2182,7 @@ impl Database {
     {
         let body = async {
             loop {
-                let tx = self.pool.begin().await?;
-
-                // In Postgres, serializable transactions are opt-in
-                if let DatabaseBackend::Postgres = self.pool.get_database_backend() {
-                    tx.execute(Statement::from_string(
-                        DatabaseBackend::Postgres,
-                        "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(),
-                    ))
-                    .await?;
-                }
-
-                let mut tx = Arc::new(Some(tx));
-                let result = f(TransactionHandle(tx.clone())).await;
-                let tx = Arc::get_mut(&mut tx).unwrap().take().unwrap();
-
+                let (tx, result) = self.with_transaction(&f).await?;
                 match result {
                     Ok((room_id, data)) => {
                         let lock = self.rooms.entry(room_id).or_default().clone();
@@ -2257,6 +2229,31 @@ impl Database {
             body.await
         }
     }
+
+    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
+    where
+        F: Send + Fn(TransactionHandle) -> Fut,
+        Fut: Send + Future<Output = Result<T>>,
+    {
+        let tx = self.pool.begin().await?;
+
+        // In Postgres, serializable transactions are opt-in
+        if let DatabaseBackend::Postgres = self.pool.get_database_backend() {
+            tx.execute(Statement::from_string(
+                DatabaseBackend::Postgres,
+                "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(),
+            ))
+            .await?;
+        }
+
+        let mut tx = Arc::new(Some(tx));
+        let result = f(TransactionHandle(tx.clone())).await;
+        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
+            return Err(anyhow!("couldn't complete transaction because it's still in use"))?;
+        };
+
+        Ok((tx, result))
+    }
 }
 
 struct TransactionHandle(Arc<Option<DatabaseTransaction>>);