Retry transactions if there's a serialization failure during commit

Antonio Scandurra created

Change summary

crates/collab/src/db.rs | 163 ++++++++++++++++++++++--------------------
1 file changed, 87 insertions(+), 76 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -2131,47 +2131,30 @@ impl Database {
         F: Send + Fn(TransactionHandle) -> Fut,
         Fut: Send + Future<Output = Result<T>>,
     {
-        let body = async {
-            loop {
-                let (tx, result) = self.with_transaction(&f).await?;
-                match result {
-                    Ok(result) => {
-                        tx.commit().await?;
-                        return Ok(result);
-                    }
-                    Err(error) => {
-                        tx.rollback().await?;
-                        match error {
-                            Error::Database(
-                                DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
-                                | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
-                            ) if error
-                                .as_database_error()
-                                .and_then(|error| error.code())
-                                .as_deref()
-                                == Some("40001") =>
-                            {
+        loop {
+            let (tx, result) = self.run(self.with_transaction(&f)).await?;
+            match result {
+                Ok(result) => {
+                    match self.run(async move { Ok(tx.commit().await?) }).await {
+                        Ok(()) => return Ok(result),
+                        Err(error) => {
+                            if is_serialization_error(&error) {
                                 // Retry (don't break the loop)
+                            } else {
+                                return Err(error);
                             }
-                            error @ _ => return Err(error),
                         }
                     }
                 }
+                Err(error) => {
+                    self.run(tx.rollback()).await?;
+                    if is_serialization_error(&error) {
+                        // Retry (don't break the loop)
+                    } else {
+                        return Err(error);
+                    }
+                }
             }
-        };
-
-        #[cfg(test)]
-        {
-            if let Some(background) = self.background.as_ref() {
-                background.simulate_random_delay().await;
-            }
-
-            self.runtime.as_ref().unwrap().block_on(body)
-        }
-
-        #[cfg(not(test))]
-        {
-            body.await
         }
     }
 
@@ -2180,53 +2163,38 @@ impl Database {
         F: Send + Fn(TransactionHandle) -> Fut,
         Fut: Send + Future<Output = Result<(RoomId, T)>>,
     {
-        let body = async {
-            loop {
-                let (tx, result) = self.with_transaction(&f).await?;
-                match result {
-                    Ok((room_id, data)) => {
-                        let lock = self.rooms.entry(room_id).or_default().clone();
-                        let _guard = lock.lock_owned().await;
-                        tx.commit().await?;
-                        return Ok(RoomGuard {
-                            data,
-                            _guard,
-                            _not_send: PhantomData,
-                        });
-                    }
-                    Err(error) => {
-                        tx.rollback().await?;
-                        match error {
-                            Error::Database(
-                                DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
-                                | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
-                            ) if error
-                                .as_database_error()
-                                .and_then(|error| error.code())
-                                .as_deref()
-                                == Some("40001") =>
-                            {
+        loop {
+            let (tx, result) = self.run(self.with_transaction(&f)).await?;
+            match result {
+                Ok((room_id, data)) => {
+                    let lock = self.rooms.entry(room_id).or_default().clone();
+                    let _guard = lock.lock_owned().await;
+                    match self.run(async move { Ok(tx.commit().await?) }).await {
+                        Ok(()) => {
+                            return Ok(RoomGuard {
+                                data,
+                                _guard,
+                                _not_send: PhantomData,
+                            });
+                        }
+                        Err(error) => {
+                            if is_serialization_error(&error) {
                                 // Retry (don't break the loop)
+                            } else {
+                                return Err(error);
                             }
-                            error @ _ => return Err(error),
                         }
                     }
                 }
+                Err(error) => {
+                    self.run(tx.rollback()).await?;
+                    if is_serialization_error(&error) {
+                        // Retry (don't break the loop)
+                    } else {
+                        return Err(error);
+                    }
+                }
             }
-        };
-
-        #[cfg(test)]
-        {
-            if let Some(background) = self.background.as_ref() {
-                background.simulate_random_delay().await;
-            }
-
-            self.runtime.as_ref().unwrap().block_on(body)
-        }
-
-        #[cfg(not(test))]
-        {
-            body.await
         }
     }
 
@@ -2254,6 +2222,49 @@ impl Database {
 
         Ok((tx, result))
     }
+
+    async fn run<F, T>(&self, future: F) -> T
+    where
+        F: Future<Output = T>,
+    {
+        #[cfg(test)]
+        {
+            if let Some(background) = self.background.as_ref() {
+                background.simulate_random_delay().await;
+            }
+
+            let result = self.runtime.as_ref().unwrap().block_on(future);
+
+            if let Some(background) = self.background.as_ref() {
+                background.simulate_random_delay().await;
+            }
+
+            result
+        }
+
+        #[cfg(not(test))]
+        {
+            future.await
+        }
+    }
+}
+
+fn is_serialization_error(error: &Error) -> bool {
+    const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
+    match error {
+        Error::Database(
+            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
+            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
+        ) if error
+            .as_database_error()
+            .and_then(|error| error.code())
+            .as_deref()
+            == Some(SERIALIZATION_FAILURE_CODE) =>
+        {
+            true
+        }
+        _ => false,
+    }
 }
 
 struct TransactionHandle(Arc<Option<DatabaseTransaction>>);