Fix panic by disallowing multiple room joins (#3149)

Conrad Irwin created

Release Notes:

- Fixed panic that could occur when switching channels quickly

Change summary

crates/call/src/call.rs                      | 107 ++++++++++++++++++-
crates/call/src/room.rs                      |  44 +++----
crates/collab/src/tests.rs                   |   4 
crates/collab/src/tests/integration_tests.rs | 115 +++++++++++++++++++++
crates/workspace/src/workspace.rs            |   4 
5 files changed, 240 insertions(+), 34 deletions(-)

Detailed changes

crates/call/src/call.rs 🔗

@@ -10,7 +10,7 @@ use client::{
     ZED_ALWAYS_ACTIVE,
 };
 use collections::HashSet;
-use futures::{future::Shared, FutureExt};
+use futures::{channel::oneshot, future::Shared, Future, FutureExt};
 use gpui::{
     AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Subscription, Task,
     WeakModelHandle,
@@ -37,10 +37,42 @@ pub struct IncomingCall {
     pub initial_project: Option<proto::ParticipantProject>,
 }
 
+pub struct OneAtATime {
+    cancel: Option<oneshot::Sender<()>>,
+}
+
+impl OneAtATime {
+    /// spawn a task in the given context.
+    /// if another task is spawned before that resolves, or if the OneAtATime itself is dropped, the first task will be cancelled and return Ok(None)
+    /// otherwise you'll see the result of the task.
+    fn spawn<F, Fut, R>(&mut self, cx: &mut AppContext, f: F) -> Task<Result<Option<R>>>
+    where
+        F: 'static + FnOnce(AsyncAppContext) -> Fut,
+        Fut: Future<Output = Result<R>>,
+        R: 'static,
+    {
+        let (tx, rx) = oneshot::channel();
+        self.cancel.replace(tx);
+        cx.spawn(|cx| async move {
+            futures::select_biased! {
+                _ = rx.fuse() => Ok(None),
+                result = f(cx).fuse() => result.map(Some),
+            }
+        })
+    }
+
+    fn running(&self) -> bool {
+        self.cancel
+            .as_ref()
+            .is_some_and(|cancel| !cancel.is_canceled())
+    }
+}
+
 /// Singleton global maintaining the user's participation in a room across workspaces.
 pub struct ActiveCall {
     room: Option<(ModelHandle<Room>, Vec<Subscription>)>,
     pending_room_creation: Option<Shared<Task<Result<ModelHandle<Room>, Arc<anyhow::Error>>>>>,
+    _join_debouncer: OneAtATime,
     location: Option<WeakModelHandle<Project>>,
     pending_invites: HashSet<u64>,
     incoming_call: (
@@ -69,6 +101,7 @@ impl ActiveCall {
             pending_invites: Default::default(),
             incoming_call: watch::channel(),
 
+            _join_debouncer: OneAtATime { cancel: None },
             _subscriptions: vec![
                 client.add_request_handler(cx.handle(), Self::handle_incoming_call),
                 client.add_message_handler(cx.handle(), Self::handle_call_canceled),
@@ -143,6 +176,10 @@ impl ActiveCall {
         }
         cx.notify();
 
+        if self._join_debouncer.running() {
+            return Task::ready(Ok(()));
+        }
+
         let room = if let Some(room) = self.room().cloned() {
             Some(Task::ready(Ok(room)).shared())
         } else {
@@ -259,11 +296,20 @@ impl ActiveCall {
             return Task::ready(Err(anyhow!("no incoming call")));
         };
 
-        let join = Room::join(&call, self.client.clone(), self.user_store.clone(), cx);
+        if self.pending_room_creation.is_some() {
+            return Task::ready(Ok(()));
+        }
+
+        let room_id = call.room_id.clone();
+        let client = self.client.clone();
+        let user_store = self.user_store.clone();
+        let join = self
+            ._join_debouncer
+            .spawn(cx, move |cx| Room::join(room_id, client, user_store, cx));
 
         cx.spawn(|this, mut cx| async move {
             let room = join.await?;
-            this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx))
+            this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx))
                 .await?;
             this.update(&mut cx, |this, cx| {
                 this.report_call_event("accept incoming", cx)
@@ -290,20 +336,28 @@ impl ActiveCall {
         &mut self,
         channel_id: u64,
         cx: &mut ModelContext<Self>,
-    ) -> Task<Result<ModelHandle<Room>>> {
+    ) -> Task<Result<Option<ModelHandle<Room>>>> {
         if let Some(room) = self.room().cloned() {
             if room.read(cx).channel_id() == Some(channel_id) {
-                return Task::ready(Ok(room));
+                return Task::ready(Ok(Some(room)));
             } else {
                 room.update(cx, |room, cx| room.clear_state(cx));
             }
         }
 
-        let join = Room::join_channel(channel_id, self.client.clone(), self.user_store.clone(), cx);
+        if self.pending_room_creation.is_some() {
+            return Task::ready(Ok(None));
+        }
 
-        cx.spawn(|this, mut cx| async move {
+        let client = self.client.clone();
+        let user_store = self.user_store.clone();
+        let join = self._join_debouncer.spawn(cx, move |cx| async move {
+            Room::join_channel(channel_id, client, user_store, cx).await
+        });
+
+        cx.spawn(move |this, mut cx| async move {
             let room = join.await?;
-            this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx))
+            this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx))
                 .await?;
             this.update(&mut cx, |this, cx| {
                 this.report_call_event("join channel", cx)
@@ -457,3 +511,40 @@ pub fn report_call_event_for_channel(
     };
     telemetry.report_clickhouse_event(event, telemetry_settings);
 }
+
+#[cfg(test)]
+mod test {
+    use gpui::TestAppContext;
+
+    use crate::OneAtATime;
+
+    #[gpui::test]
+    async fn test_one_at_a_time(cx: &mut TestAppContext) {
+        let mut one_at_a_time = OneAtATime { cancel: None };
+
+        assert_eq!(
+            cx.update(|cx| one_at_a_time.spawn(cx, |_| async { Ok(1) }))
+                .await
+                .unwrap(),
+            Some(1)
+        );
+
+        let (a, b) = cx.update(|cx| {
+            (
+                one_at_a_time.spawn(cx, |_| async {
+                    assert!(false);
+                    Ok(2)
+                }),
+                one_at_a_time.spawn(cx, |_| async { Ok(3) }),
+            )
+        });
+
+        assert_eq!(a.await.unwrap(), None);
+        assert_eq!(b.await.unwrap(), Some(3));
+
+        let promise = cx.update(|cx| one_at_a_time.spawn(cx, |_| async { Ok(4) }));
+        drop(one_at_a_time);
+
+        assert_eq!(promise.await.unwrap(), None);
+    }
+}

crates/call/src/room.rs 🔗

@@ -1,7 +1,6 @@
 use crate::{
     call_settings::CallSettings,
     participant::{LocalParticipant, ParticipantLocation, RemoteParticipant, RemoteVideoTrack},
-    IncomingCall,
 };
 use anyhow::{anyhow, Result};
 use audio::{Audio, Sound};
@@ -291,37 +290,32 @@ impl Room {
         })
     }
 
-    pub(crate) fn join_channel(
+    pub(crate) async fn join_channel(
         channel_id: u64,
         client: Arc<Client>,
         user_store: ModelHandle<UserStore>,
-        cx: &mut AppContext,
-    ) -> Task<Result<ModelHandle<Self>>> {
-        cx.spawn(|cx| async move {
-            Self::from_join_response(
-                client.request(proto::JoinChannel { channel_id }).await?,
-                client,
-                user_store,
-                cx,
-            )
-        })
+        cx: AsyncAppContext,
+    ) -> Result<ModelHandle<Self>> {
+        Self::from_join_response(
+            client.request(proto::JoinChannel { channel_id }).await?,
+            client,
+            user_store,
+            cx,
+        )
     }
 
-    pub(crate) fn join(
-        call: &IncomingCall,
+    pub(crate) async fn join(
+        room_id: u64,
         client: Arc<Client>,
         user_store: ModelHandle<UserStore>,
-        cx: &mut AppContext,
-    ) -> Task<Result<ModelHandle<Self>>> {
-        let id = call.room_id;
-        cx.spawn(|cx| async move {
-            Self::from_join_response(
-                client.request(proto::JoinRoom { id }).await?,
-                client,
-                user_store,
-                cx,
-            )
-        })
+        cx: AsyncAppContext,
+    ) -> Result<ModelHandle<Self>> {
+        Self::from_join_response(
+            client.request(proto::JoinRoom { id: room_id }).await?,
+            client,
+            user_store,
+            cx,
+        )
     }
 
     pub fn mute_on_join(cx: &AppContext) -> bool {

crates/collab/src/tests.rs 🔗

@@ -40,3 +40,7 @@ fn room_participants(room: &ModelHandle<Room>, cx: &mut TestAppContext) -> RoomP
         RoomParticipants { remote, pending }
     })
 }
+
+fn channel_id(room: &ModelHandle<Room>, cx: &mut TestAppContext) -> Option<u64> {
+    cx.read(|cx| room.read(cx).channel_id())
+}

crates/collab/src/tests/integration_tests.rs 🔗

@@ -1,6 +1,6 @@
 use crate::{
     rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
-    tests::{room_participants, RoomParticipants, TestClient, TestServer},
+    tests::{channel_id, room_participants, RoomParticipants, TestClient, TestServer},
 };
 use call::{room, ActiveCall, ParticipantLocation, Room};
 use client::{User, RECEIVE_TIMEOUT};
@@ -469,6 +469,119 @@ async fn test_calling_multiple_users_simultaneously(
     );
 }
 
+#[gpui::test(iterations = 10)]
+async fn test_joining_channels_and_calling_multiple_users_simultaneously(
+    deterministic: Arc<Deterministic>,
+    cx_a: &mut TestAppContext,
+    cx_b: &mut TestAppContext,
+    cx_c: &mut TestAppContext,
+) {
+    deterministic.forbid_parking();
+    let mut server = TestServer::start(&deterministic).await;
+
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_b").await;
+    let client_c = server.create_client(cx_c, "user_c").await;
+    server
+        .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)])
+        .await;
+
+    let channel_1 = server
+        .make_channel(
+            "channel1",
+            None,
+            (&client_a, cx_a),
+            &mut [(&client_b, cx_b), (&client_c, cx_c)],
+        )
+        .await;
+
+    let channel_2 = server
+        .make_channel(
+            "channel2",
+            None,
+            (&client_a, cx_a),
+            &mut [(&client_b, cx_b), (&client_c, cx_c)],
+        )
+        .await;
+
+    let active_call_a = cx_a.read(ActiveCall::global);
+
+    // Simultaneously join channel 1 and then channel 2
+    active_call_a
+        .update(cx_a, |call, cx| call.join_channel(channel_1, cx))
+        .detach();
+    let join_channel_2 = active_call_a.update(cx_a, |call, cx| call.join_channel(channel_2, cx));
+
+    join_channel_2.await.unwrap();
+
+    let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone());
+    deterministic.run_until_parked();
+
+    assert_eq!(channel_id(&room_a, cx_a), Some(channel_2));
+
+    // Leave the room
+    active_call_a
+        .update(cx_a, |call, cx| {
+            let hang_up = call.hang_up(cx);
+            hang_up
+        })
+        .await
+        .unwrap();
+
+    // Initiating invites and then joining a channel should fail gracefully
+    let b_invite = active_call_a.update(cx_a, |call, cx| {
+        call.invite(client_b.user_id().unwrap(), None, cx)
+    });
+    let c_invite = active_call_a.update(cx_a, |call, cx| {
+        call.invite(client_c.user_id().unwrap(), None, cx)
+    });
+
+    let join_channel = active_call_a.update(cx_a, |call, cx| call.join_channel(channel_1, cx));
+
+    b_invite.await.unwrap();
+    c_invite.await.unwrap();
+    join_channel.await.unwrap();
+
+    let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone());
+    deterministic.run_until_parked();
+
+    assert_eq!(
+        room_participants(&room_a, cx_a),
+        RoomParticipants {
+            remote: Default::default(),
+            pending: vec!["user_b".to_string(), "user_c".to_string()]
+        }
+    );
+
+    assert_eq!(channel_id(&room_a, cx_a), None);
+
+    // Leave the room
+    active_call_a
+        .update(cx_a, |call, cx| {
+            let hang_up = call.hang_up(cx);
+            hang_up
+        })
+        .await
+        .unwrap();
+
+    // Simultaneously join channel 1 and call user B and user C from client A.
+    let join_channel = active_call_a.update(cx_a, |call, cx| call.join_channel(channel_1, cx));
+
+    let b_invite = active_call_a.update(cx_a, |call, cx| {
+        call.invite(client_b.user_id().unwrap(), None, cx)
+    });
+    let c_invite = active_call_a.update(cx_a, |call, cx| {
+        call.invite(client_c.user_id().unwrap(), None, cx)
+    });
+
+    join_channel.await.unwrap();
+    b_invite.await.unwrap();
+    c_invite.await.unwrap();
+
+    active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone());
+    deterministic.run_until_parked();
+}
+
 #[gpui::test(iterations = 10)]
 async fn test_room_uniqueness(
     deterministic: Arc<Deterministic>,

crates/workspace/src/workspace.rs 🔗

@@ -4238,6 +4238,10 @@ async fn join_channel_internal(
         })
         .await?;
 
+    let Some(room) = room else {
+        return anyhow::Ok(true);
+    };
+
     room.update(cx, |room, _| room.room_update_completed())
         .await;