Merge pull request #2296 from zed-industries/tx-serialization-retry-delay

Max Brunsfeld created

Introduce a delay before retrying a transaction after a serialization failure

Change summary

crates/collab/Cargo.toml      |   2 
crates/collab/src/bin/seed.rs |   4 
crates/collab/src/db.rs       | 135 ++++++++++++++++++++----------------
crates/collab/src/lib.rs      |   3 
crates/collab/src/main.rs     |   2 
5 files changed, 79 insertions(+), 67 deletions(-)

Detailed changes

crates/collab/Cargo.toml 🔗

@@ -31,6 +31,7 @@ futures = "0.3"
 hyper = "0.14"
 lazy_static = "1.4"
 lipsum = { version = "0.8", optional = true }
+log = { version = "0.4.16", features = ["kv_unstable_serde"] }
 nanoid = "0.4"
 parking_lot = "0.11.1"
 prometheus = "0.13"
@@ -74,7 +75,6 @@ workspace = { path = "../workspace", features = ["test-support"] }
 
 ctor = "0.1"
 env_logger = "0.9"
-log = { version = "0.4.16", features = ["kv_unstable_serde"] }
 util = { path = "../util" }
 lazy_static = "1.4"
 sea-orm = { git = "https://github.com/zed-industries/sea-orm", rev = "18f4c691085712ad014a51792af75a9044bacee6", features = ["sqlx-sqlite"] }

crates/collab/src/bin/seed.rs 🔗

@@ -1,4 +1,4 @@
-use collab::db;
+use collab::{db, executor::Executor};
 use db::{ConnectOptions, Database};
 use serde::{de::DeserializeOwned, Deserialize};
 use std::fmt::Write;
@@ -13,7 +13,7 @@ struct GitHubUser {
 #[tokio::main]
 async fn main() {
     let database_url = std::env::var("DATABASE_URL").expect("missing DATABASE_URL env var");
-    let db = Database::new(ConnectOptions::new(database_url))
+    let db = Database::new(ConnectOptions::new(database_url), Executor::Production)
         .await
         .expect("failed to connect to postgres database");
     let github_token = std::env::var("GITHUB_TOKEN").expect("missing GITHUB_TOKEN env var");

crates/collab/src/db.rs 🔗

@@ -15,6 +15,7 @@ mod worktree;
 mod worktree_diagnostic_summary;
 mod worktree_entry;
 
+use crate::executor::Executor;
 use crate::{Error, Result};
 use anyhow::anyhow;
 use collections::{BTreeMap, HashMap, HashSet};
@@ -22,6 +23,8 @@ pub use contact::Contact;
 use dashmap::DashMap;
 use futures::StreamExt;
 use hyper::StatusCode;
+use rand::prelude::StdRng;
+use rand::{Rng, SeedableRng};
 use rpc::{proto, ConnectionId};
 use sea_orm::Condition;
 pub use sea_orm::ConnectOptions;
@@ -46,20 +49,20 @@ pub struct Database {
     options: ConnectOptions,
     pool: DatabaseConnection,
     rooms: DashMap<RoomId, Arc<Mutex<()>>>,
-    #[cfg(test)]
-    background: Option<std::sync::Arc<gpui::executor::Background>>,
+    rng: Mutex<StdRng>,
+    executor: Executor,
     #[cfg(test)]
     runtime: Option<tokio::runtime::Runtime>,
 }
 
 impl Database {
-    pub async fn new(options: ConnectOptions) -> Result<Self> {
+    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
         Ok(Self {
             options: options.clone(),
             pool: sea_orm::Database::connect(options).await?,
             rooms: DashMap::with_capacity(16384),
-            #[cfg(test)]
-            background: None,
+            rng: Mutex::new(StdRng::seed_from_u64(0)),
+            executor,
             #[cfg(test)]
             runtime: None,
         })
@@ -2805,30 +2808,26 @@ impl Database {
         Fut: Send + Future<Output = Result<T>>,
     {
         let body = async {
+            let mut i = 0;
             loop {
                 let (tx, result) = self.with_transaction(&f).await?;
                 match result {
-                    Ok(result) => {
-                        match tx.commit().await.map_err(Into::into) {
-                            Ok(()) => return Ok(result),
-                            Err(error) => {
-                                if is_serialization_error(&error) {
-                                    // Retry (don't break the loop)
-                                } else {
-                                    return Err(error);
-                                }
+                    Ok(result) => match tx.commit().await.map_err(Into::into) {
+                        Ok(()) => return Ok(result),
+                        Err(error) => {
+                            if !self.retry_on_serialization_error(&error, i).await {
+                                return Err(error);
                             }
                         }
-                    }
+                    },
                     Err(error) => {
                         tx.rollback().await?;
-                        if is_serialization_error(&error) {
-                            // Retry (don't break the loop)
-                        } else {
+                        if !self.retry_on_serialization_error(&error, i).await {
                             return Err(error);
                         }
                     }
                 }
+                i += 1;
             }
         };
 
@@ -2841,6 +2840,7 @@ impl Database {
         Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
     {
         let body = async {
+            let mut i = 0;
             loop {
                 let (tx, result) = self.with_transaction(&f).await?;
                 match result {
@@ -2856,35 +2856,28 @@ impl Database {
                                 }));
                             }
                             Err(error) => {
-                                if is_serialization_error(&error) {
-                                    // Retry (don't break the loop)
-                                } else {
+                                if !self.retry_on_serialization_error(&error, i).await {
                                     return Err(error);
                                 }
                             }
                         }
                     }
-                    Ok(None) => {
-                        match tx.commit().await.map_err(Into::into) {
-                            Ok(()) => return Ok(None),
-                            Err(error) => {
-                                if is_serialization_error(&error) {
-                                    // Retry (don't break the loop)
-                                } else {
-                                    return Err(error);
-                                }
+                    Ok(None) => match tx.commit().await.map_err(Into::into) {
+                        Ok(()) => return Ok(None),
+                        Err(error) => {
+                            if !self.retry_on_serialization_error(&error, i).await {
+                                return Err(error);
                             }
                         }
-                    }
+                    },
                     Err(error) => {
                         tx.rollback().await?;
-                        if is_serialization_error(&error) {
-                            // Retry (don't break the loop)
-                        } else {
+                        if !self.retry_on_serialization_error(&error, i).await {
                             return Err(error);
                         }
                     }
                 }
+                i += 1;
             }
         };
 
@@ -2897,38 +2890,34 @@ impl Database {
         Fut: Send + Future<Output = Result<T>>,
     {
         let body = async {
+            let mut i = 0;
             loop {
                 let lock = self.rooms.entry(room_id).or_default().clone();
                 let _guard = lock.lock_owned().await;
                 let (tx, result) = self.with_transaction(&f).await?;
                 match result {
-                    Ok(data) => {
-                        match tx.commit().await.map_err(Into::into) {
-                            Ok(()) => {
-                                return Ok(RoomGuard {
-                                    data,
-                                    _guard,
-                                    _not_send: PhantomData,
-                                });
-                            }
-                            Err(error) => {
-                                if is_serialization_error(&error) {
-                                    // Retry (don't break the loop)
-                                } else {
-                                    return Err(error);
-                                }
+                    Ok(data) => match tx.commit().await.map_err(Into::into) {
+                        Ok(()) => {
+                            return Ok(RoomGuard {
+                                data,
+                                _guard,
+                                _not_send: PhantomData,
+                            });
+                        }
+                        Err(error) => {
+                            if !self.retry_on_serialization_error(&error, i).await {
+                                return Err(error);
                             }
                         }
-                    }
+                    },
                     Err(error) => {
                         tx.rollback().await?;
-                        if is_serialization_error(&error) {
-                            // Retry (don't break the loop)
-                        } else {
+                        if !self.retry_on_serialization_error(&error, i).await {
                             return Err(error);
                         }
                     }
                 }
+                i += 1;
             }
         };
 
@@ -2954,14 +2943,14 @@ impl Database {
         Ok((tx, result))
     }
 
-    async fn run<F, T>(&self, future: F) -> T
+    async fn run<F, T>(&self, future: F) -> Result<T>
     where
-        F: Future<Output = T>,
+        F: Future<Output = Result<T>>,
     {
         #[cfg(test)]
         {
-            if let Some(background) = self.background.as_ref() {
-                background.simulate_random_delay().await;
+            if let Executor::Deterministic(executor) = &self.executor {
+                executor.simulate_random_delay().await;
             }
 
             self.runtime.as_ref().unwrap().block_on(future)
@@ -2972,6 +2961,27 @@ impl Database {
             future.await
         }
     }
+
+    async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
+        // If the error is due to a failure to serialize concurrent transactions, then retry
+        // this transaction after a delay. With each subsequent retry, double the delay duration.
+        // Also vary the delay randomly in order to ensure different database connections retry
+        // at different times.
+        if is_serialization_error(error) {
+            let base_delay = 4_u64 << prev_attempt_count.min(16);
+            let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
+            log::info!(
+                "retrying transaction after serialization error. delay: {} ms.",
+                randomized_delay
+            );
+            self.executor
+                .sleep(Duration::from_millis(randomized_delay as u64))
+                .await;
+            true
+        } else {
+            false
+        }
+    }
 }
 
 fn is_serialization_error(error: &Error) -> bool {
@@ -3273,7 +3283,6 @@ mod test {
     use gpui::executor::Background;
     use lazy_static::lazy_static;
     use parking_lot::Mutex;
-    use rand::prelude::*;
     use sea_orm::ConnectionTrait;
     use sqlx::migrate::MigrateDatabase;
     use std::sync::Arc;
@@ -3295,7 +3304,9 @@ mod test {
             let mut db = runtime.block_on(async {
                 let mut options = ConnectOptions::new(url);
                 options.max_connections(5);
-                let db = Database::new(options).await.unwrap();
+                let db = Database::new(options, Executor::Deterministic(background))
+                    .await
+                    .unwrap();
                 let sql = include_str!(concat!(
                     env!("CARGO_MANIFEST_DIR"),
                     "/migrations.sqlite/20221109000000_test_schema.sql"
@@ -3310,7 +3321,6 @@ mod test {
                 db
             });
 
-            db.background = Some(background);
             db.runtime = Some(runtime);
 
             Self {
@@ -3344,13 +3354,14 @@ mod test {
                 options
                     .max_connections(5)
                     .idle_timeout(Duration::from_secs(0));
-                let db = Database::new(options).await.unwrap();
+                let db = Database::new(options, Executor::Deterministic(background))
+                    .await
+                    .unwrap();
                 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                 db.migrate(Path::new(migrations_path), false).await.unwrap();
                 db
             });
 
-            db.background = Some(background);
             db.runtime = Some(runtime);
 
             Self {

crates/collab/src/lib.rs 🔗

@@ -10,6 +10,7 @@ mod tests;
 
 use axum::{http::StatusCode, response::IntoResponse};
 use db::Database;
+use executor::Executor;
 use serde::Deserialize;
 use std::{path::PathBuf, sync::Arc};
 
@@ -118,7 +119,7 @@ impl AppState {
     pub async fn new(config: Config) -> Result<Arc<Self>> {
         let mut db_options = db::ConnectOptions::new(config.database_url.clone());
         db_options.max_connections(config.database_max_connections);
-        let db = Database::new(db_options).await?;
+        let db = Database::new(db_options, Executor::Production).await?;
         let live_kit_client = if let Some(((server, key), secret)) = config
             .live_kit_server
             .as_ref()

crates/collab/src/main.rs 🔗

@@ -32,7 +32,7 @@ async fn main() -> Result<()> {
             let config = envy::from_env::<MigrateConfig>().expect("error loading config");
             let mut db_options = db::ConnectOptions::new(config.database_url.clone());
             db_options.max_connections(5);
-            let db = Database::new(db_options).await?;
+            let db = Database::new(db_options, Executor::Production).await?;
 
             let migrations_path = config
                 .migrations_path