collab fixes (#6720)

Conrad Irwin created

- Fail faster on serialization failure
- Move expensive participant update out of transaction

Release Notes:

- Fixed creating/moving channels in busy workspaces

Change summary

crates/collab/src/db.rs                     | 60 +++++++++++++-
crates/collab/src/db/queries/channels.rs    | 69 +++++------------
crates/collab/src/db/tests/message_tests.rs | 25 ++----
crates/collab/src/rpc.rs                    | 89 +++++++++++++---------
4 files changed, 135 insertions(+), 108 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -169,6 +169,30 @@ impl Database {
         self.run(body).await
     }
 
+    pub async fn weak_transaction<F, Fut, T>(&self, f: F) -> Result<T>
+    where
+        F: Send + Fn(TransactionHandle) -> Fut,
+        Fut: Send + Future<Output = Result<T>>,
+    {
+        let body = async {
+            let (tx, result) = self.with_weak_transaction(&f).await?;
+            match result {
+                Ok(result) => match tx.commit().await.map_err(Into::into) {
+                    Ok(()) => return Ok(result),
+                    Err(error) => {
+                        return Err(error);
+                    }
+                },
+                Err(error) => {
+                    tx.rollback().await?;
+                    return Err(error);
+                }
+            }
+        };
+
+        self.run(body).await
+    }
+
     /// The same as room_transaction, but if you need to only optionally return a Room.
     async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
     where
@@ -284,6 +308,30 @@ impl Database {
         Ok((tx, result))
     }
 
+    async fn with_weak_transaction<F, Fut, T>(
+        &self,
+        f: &F,
+    ) -> Result<(DatabaseTransaction, Result<T>)>
+    where
+        F: Send + Fn(TransactionHandle) -> Fut,
+        Fut: Send + Future<Output = Result<T>>,
+    {
+        let tx = self
+            .pool
+            .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
+            .await?;
+
+        let mut tx = Arc::new(Some(tx));
+        let result = f(TransactionHandle(tx.clone())).await;
+        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
+            return Err(anyhow!(
+                "couldn't complete transaction because it's still in use"
+            ))?;
+        };
+
+        Ok((tx, result))
+    }
+
     async fn run<F, T>(&self, future: F) -> Result<T>
     where
         F: Future<Output = Result<T>>,
@@ -303,13 +351,14 @@ impl Database {
         }
     }
 
-    async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
+    async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: usize) -> bool {
         // If the error is due to a failure to serialize concurrent transactions, then retry
         // this transaction after a delay. With each subsequent retry, double the delay duration.
         // Also vary the delay randomly in order to ensure different database connections retry
         // at different times.
-        if is_serialization_error(error) {
-            let base_delay = 4_u64 << prev_attempt_count.min(16);
+        const SLEEPS: [f32; 10] = [10., 20., 40., 80., 160., 320., 640., 1280., 2560., 5120.];
+        if is_serialization_error(error) && prev_attempt_count < SLEEPS.len() {
+            let base_delay = SLEEPS[prev_attempt_count];
             let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
             log::info!(
                 "retrying transaction after serialization error. delay: {} ms.",
@@ -456,9 +505,8 @@ pub struct NewUserResult {
 /// The result of moving a channel.
 #[derive(Debug)]
 pub struct MoveChannelResult {
-    pub participants_to_update: HashMap<UserId, ChannelsForUser>,
-    pub participants_to_remove: HashSet<UserId>,
-    pub moved_channels: HashSet<ChannelId>,
+    pub previous_participants: Vec<ChannelMember>,
+    pub descendent_ids: Vec<ChannelId>,
 }
 
 /// The result of renaming a channel.

crates/collab/src/db/queries/channels.rs 🔗

@@ -19,11 +19,7 @@ impl Database {
 
     #[cfg(test)]
     pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result<ChannelId> {
-        Ok(self
-            .create_channel(name, None, creator_id)
-            .await?
-            .channel
-            .id)
+        Ok(self.create_channel(name, None, creator_id).await?.id)
     }
 
     #[cfg(test)]
@@ -36,7 +32,6 @@ impl Database {
         Ok(self
             .create_channel(name, Some(parent), creator_id)
             .await?
-            .channel
             .id)
     }
 
@@ -46,7 +41,7 @@ impl Database {
         name: &str,
         parent_channel_id: Option<ChannelId>,
         admin_id: UserId,
-    ) -> Result<CreateChannelResult> {
+    ) -> Result<Channel> {
         let name = Self::sanitize_channel_name(name)?;
         self.transaction(move |tx| async move {
             let mut parent = None;
@@ -72,14 +67,7 @@ impl Database {
             .insert(&*tx)
             .await?;
 
-            let participants_to_update;
-            if let Some(parent) = &parent {
-                participants_to_update = self
-                    .participants_to_notify_for_channel_change(parent, &*tx)
-                    .await?;
-            } else {
-                participants_to_update = vec![];
-
+            if parent.is_none() {
                 channel_member::ActiveModel {
                     id: ActiveValue::NotSet,
                     channel_id: ActiveValue::Set(channel.id),
@@ -89,12 +77,9 @@ impl Database {
                 }
                 .insert(&*tx)
                 .await?;
-            };
+            }
 
-            Ok(CreateChannelResult {
-                channel: Channel::from_model(channel, ChannelRole::Admin),
-                participants_to_update,
-            })
+            Ok(Channel::from_model(channel, ChannelRole::Admin))
         })
         .await
     }
@@ -718,6 +703,19 @@ impl Database {
         })
     }
 
+    pub async fn new_participants_to_notify(
+        &self,
+        parent_channel_id: ChannelId,
+    ) -> Result<Vec<(UserId, ChannelsForUser)>> {
+        self.weak_transaction(|tx| async move {
+            let parent_channel = self.get_channel_internal(parent_channel_id, &*tx).await?;
+            self.participants_to_notify_for_channel_change(&parent_channel, &*tx)
+                .await
+        })
+        .await
+    }
+
+    // TODO: this is very expensive, and we should rethink this approach
     async fn participants_to_notify_for_channel_change(
         &self,
         new_parent: &channel::Model,
@@ -1287,7 +1285,7 @@ impl Database {
 
             let mut model = channel.into_active_model();
             model.parent_path = ActiveValue::Set(new_parent_path);
-            let channel = model.update(&*tx).await?;
+            model.update(&*tx).await?;
 
             if new_parent_channel.is_none() {
                 channel_member::ActiveModel {
@@ -1314,34 +1312,9 @@ impl Database {
                 .all(&*tx)
                 .await?;
 
-            let participants_to_update: HashMap<_, _> = self
-                .participants_to_notify_for_channel_change(
-                    new_parent_channel.as_ref().unwrap_or(&channel),
-                    &*tx,
-                )
-                .await?
-                .into_iter()
-                .collect();
-
-            let mut moved_channels: HashSet<ChannelId> = HashSet::default();
-            for id in descendent_ids {
-                moved_channels.insert(id);
-            }
-            moved_channels.insert(channel_id);
-
-            let mut participants_to_remove: HashSet<UserId> = HashSet::default();
-            for participant in previous_participants {
-                if participant.kind == proto::channel_member::Kind::AncestorMember {
-                    if !participants_to_update.contains_key(&participant.user_id) {
-                        participants_to_remove.insert(participant.user_id);
-                    }
-                }
-            }
-
             Ok(Some(MoveChannelResult {
-                participants_to_remove,
-                participants_to_update,
-                moved_channels,
+                previous_participants,
+                descendent_ids,
             }))
         })
         .await

crates/collab/src/db/tests/message_tests.rs 🔗

@@ -15,22 +15,18 @@ test_both_dbs!(
 
 async fn test_channel_message_retrieval(db: &Arc<Database>) {
     let user = new_test_user(db, "user@example.com").await;
-    let result = db.create_channel("channel", None, user).await.unwrap();
+    let channel = db.create_channel("channel", None, user).await.unwrap();
 
     let owner_id = db.create_server("test").await.unwrap().0 as u32;
-    db.join_channel_chat(
-        result.channel.id,
-        rpc::ConnectionId { owner_id, id: 0 },
-        user,
-    )
-    .await
-    .unwrap();
+    db.join_channel_chat(channel.id, rpc::ConnectionId { owner_id, id: 0 }, user)
+        .await
+        .unwrap();
 
     let mut all_messages = Vec::new();
     for i in 0..10 {
         all_messages.push(
             db.create_channel_message(
-                result.channel.id,
+                channel.id,
                 user,
                 &i.to_string(),
                 &[],
@@ -45,7 +41,7 @@ async fn test_channel_message_retrieval(db: &Arc<Database>) {
     }
 
     let messages = db
-        .get_channel_messages(result.channel.id, user, 3, None)
+        .get_channel_messages(channel.id, user, 3, None)
         .await
         .unwrap()
         .into_iter()
@@ -55,7 +51,7 @@ async fn test_channel_message_retrieval(db: &Arc<Database>) {
 
     let messages = db
         .get_channel_messages(
-            result.channel.id,
+            channel.id,
             user,
             4,
             Some(MessageId::from_proto(all_messages[6])),
@@ -366,12 +362,7 @@ async fn test_channel_message_mentions(db: &Arc<Database>) {
     let user_b = new_test_user(db, "user_b@example.com").await;
     let user_c = new_test_user(db, "user_c@example.com").await;
 
-    let channel = db
-        .create_channel("channel", None, user_a)
-        .await
-        .unwrap()
-        .channel
-        .id;
+    let channel = db.create_channel("channel", None, user_a).await.unwrap().id;
     db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member)
         .await
         .unwrap();

crates/collab/src/rpc.rs 🔗

@@ -3,11 +3,10 @@ mod connection_pool;
 use crate::{
     auth::{self, Impersonator},
     db::{
-        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
-        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
-        MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
-        RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
-        User, UserId,
+        self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
+        InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
+        RemoveChannelMemberResult, RenameChannelResult, RespondToChannelInvite, RoomId, ServerId,
+        SetChannelVisibilityResult, User, UserId,
     },
     executor::Executor,
     AppState, Error, Result,
@@ -2301,10 +2300,7 @@ async fn create_channel(
     let db = session.db().await;
 
     let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
-    let CreateChannelResult {
-        channel,
-        participants_to_update,
-    } = db
+    let channel = db
         .create_channel(&request.name, parent_id, session.user_id)
         .await?;
 
@@ -2313,6 +2309,13 @@ async fn create_channel(
         parent_id: request.parent_id,
     })?;
 
+    let participants_to_update;
+    if let Some(parent) = parent_id {
+        participants_to_update = db.new_participants_to_notify(parent).await?;
+    } else {
+        participants_to_update = vec![];
+    }
+
     let connection_pool = session.connection_pool().await;
     for (user_id, channels) in participants_to_update {
         let update = build_channels_update(channels, vec![]);
@@ -2572,41 +2575,53 @@ async fn move_channel(
         .move_channel(channel_id, to, session.user_id)
         .await?;
 
-    notify_channel_moved(result, session).await?;
+    if let Some(result) = result {
+        let participants_to_update: HashMap<_, _> = session
+            .db()
+            .await
+            .new_participants_to_notify(to.unwrap_or(channel_id))
+            .await?
+            .into_iter()
+            .collect();
 
-    response.send(Ack {})?;
-    Ok(())
-}
+        let mut moved_channels: HashSet<ChannelId> = HashSet::default();
+        for id in result.descendent_ids {
+            moved_channels.insert(id);
+        }
+        moved_channels.insert(channel_id);
 
-async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
-    let Some(MoveChannelResult {
-        participants_to_remove,
-        participants_to_update,
-        moved_channels,
-    }) = result
-    else {
-        return Ok(());
-    };
-    let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
+        let mut participants_to_remove: HashSet<UserId> = HashSet::default();
+        for participant in result.previous_participants {
+            if participant.kind == proto::channel_member::Kind::AncestorMember {
+                if !participants_to_update.contains_key(&participant.user_id) {
+                    participants_to_remove.insert(participant.user_id);
+                }
+            }
+        }
 
-    let connection_pool = session.connection_pool().await;
-    for (user_id, channels) in participants_to_update {
-        let mut update = build_channels_update(channels, vec![]);
-        update.delete_channels = moved_channels.clone();
-        for connection_id in connection_pool.user_connection_ids(user_id) {
-            session.peer.send(connection_id, update.clone())?;
+        let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
+
+        let connection_pool = session.connection_pool().await;
+        for (user_id, channels) in participants_to_update {
+            let mut update = build_channels_update(channels, vec![]);
+            update.delete_channels = moved_channels.clone();
+            for connection_id in connection_pool.user_connection_ids(user_id) {
+                session.peer.send(connection_id, update.clone())?;
+            }
         }
-    }
 
-    for user_id in participants_to_remove {
-        let update = proto::UpdateChannels {
-            delete_channels: moved_channels.clone(),
-            ..Default::default()
-        };
-        for connection_id in connection_pool.user_connection_ids(user_id) {
-            session.peer.send(connection_id, update.clone())?;
+        for user_id in participants_to_remove {
+            let update = proto::UpdateChannels {
+                delete_channels: moved_channels.clone(),
+                ..Default::default()
+            };
+            for connection_id in connection_pool.user_connection_ids(user_id) {
+                session.peer.send(connection_id, update.clone())?;
+            }
         }
     }
+
+    response.send(Ack {})?;
     Ok(())
 }