diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index bc074e30df5ac6bc4d80fe62e42ee6cd78ed6387..dfd1d7e65a1d1467aac38d8694d72fd981a6c1da 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -2131,31 +2131,35 @@ impl Database { F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>, { - loop { - let (tx, result) = self.run(self.with_transaction(&f)).await?; - match result { - Ok(result) => { - match self.run(async move { Ok(tx.commit().await?) }).await { - Ok(()) => return Ok(result), - Err(error) => { - if is_serialization_error(&error) { - // Retry (don't break the loop) - } else { - return Err(error); + let body = async { + loop { + let (tx, result) = self.with_transaction(&f).await?; + match result { + Ok(result) => { + match tx.commit().await.map_err(Into::into) { + Ok(()) => return Ok(result), + Err(error) => { + if is_serialization_error(&error) { + // Retry (don't break the loop) + } else { + return Err(error); + } } } } - } - Err(error) => { - self.run(tx.rollback()).await?; - if is_serialization_error(&error) { - // Retry (don't break the loop) - } else { - return Err(error); + Err(error) => { + tx.rollback().await?; + if is_serialization_error(&error) { + // Retry (don't break the loop) + } else { + return Err(error); + } } } } - } + }; + + self.run(body).await } async fn room_transaction(&self, f: F) -> Result> @@ -2163,39 +2167,43 @@ impl Database { F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>, { - loop { - let (tx, result) = self.run(self.with_transaction(&f)).await?; - match result { - Ok((room_id, data)) => { - let lock = self.rooms.entry(room_id).or_default().clone(); - let _guard = lock.lock_owned().await; - match self.run(async move { Ok(tx.commit().await?) }).await { - Ok(()) => { - return Ok(RoomGuard { - data, - _guard, - _not_send: PhantomData, - }); - } - Err(error) => { - if is_serialization_error(&error) { - // Retry (don't break the loop) - } else { - return Err(error); + let body = async { + loop { + let (tx, result) = self.with_transaction(&f).await?; + match result { + Ok((room_id, data)) => { + let lock = self.rooms.entry(room_id).or_default().clone(); + let _guard = lock.lock_owned().await; + match tx.commit().await.map_err(Into::into) { + Ok(()) => { + return Ok(RoomGuard { + data, + _guard, + _not_send: PhantomData, + }); + } + Err(error) => { + if is_serialization_error(&error) { + // Retry (don't break the loop) + } else { + return Err(error); + } } } } - } - Err(error) => { - self.run(tx.rollback()).await?; - if is_serialization_error(&error) { - // Retry (don't break the loop) - } else { - return Err(error); + Err(error) => { + tx.rollback().await?; + if is_serialization_error(&error) { + // Retry (don't break the loop) + } else { + return Err(error); + } } } } - } + }; + + self.run(body).await } async fn with_transaction(&self, f: &F) -> Result<(DatabaseTransaction, Result)> @@ -2233,13 +2241,7 @@ impl Database { background.simulate_random_delay().await; } - let result = self.runtime.as_ref().unwrap().block_on(future); - - if let Some(background) = self.background.as_ref() { - background.simulate_random_delay().await; - } - - result + self.runtime.as_ref().unwrap().block_on(future) } #[cfg(not(test))] diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 73f450b8336c092757e3ca872ce763ab0a405558..4ff372efbe95d4a80d646ddabd87bc9f6267378b 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -5672,7 +5672,13 @@ impl TestServer { async fn start(background: Arc) -> Self { static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0); - let test_db = TestDb::sqlite(background.clone()); + let use_postgres = env::var("USE_POSTGRES").ok(); + let use_postgres = use_postgres.as_deref(); + let test_db = if use_postgres == Some("true") || use_postgres == Some("1") { + TestDb::postgres(background.clone()) + } else { + TestDb::sqlite(background.clone()) + }; let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst); let live_kit_server = live_kit_client::TestServer::create( format!("http://livekit.{}.test", live_kit_server_id),