From fd8e6110b1328e8ee6d8828d28620c7b53574167 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Fri, 20 Oct 2023 14:31:10 -0600 Subject: [PATCH] Fix panic by disallowing multiple room joins --- crates/call/src/call.rs | 89 ++++++++++++++++++++++++++++--- crates/call/src/room.rs | 44 +++++++-------- crates/workspace/src/workspace.rs | 4 ++ 3 files changed, 104 insertions(+), 33 deletions(-) diff --git a/crates/call/src/call.rs b/crates/call/src/call.rs index 08463413257b4b10972ad08ab7c64d181599d914..98233e1a10b896073b68f91fa4e52d70aa649f3f 100644 --- a/crates/call/src/call.rs +++ b/crates/call/src/call.rs @@ -10,7 +10,7 @@ use client::{ ZED_ALWAYS_ACTIVE, }; use collections::HashSet; -use futures::{future::Shared, FutureExt}; +use futures::{channel::oneshot, future::Shared, Future, FutureExt}; use gpui::{ AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Subscription, Task, WeakModelHandle, @@ -37,6 +37,31 @@ pub struct IncomingCall { pub initial_project: Option, } +pub struct OneAtATime { + cancel: Option>, +} + +impl OneAtATime { + /// spawn a task in the given context. + /// if another task is spawned before that resolves, or if the OneAtATime itself is dropped, the first task will be cancelled and return Ok(None) + /// otherwise you'll see the result of the task. + fn spawn(&mut self, cx: &mut AppContext, f: F) -> Task>> + where + F: 'static + FnOnce(AsyncAppContext) -> Fut, + Fut: Future>, + R: 'static, + { + let (tx, rx) = oneshot::channel(); + self.cancel.replace(tx); + cx.spawn(|cx| async move { + futures::select_biased! { + _ = rx.fuse() => Ok(None), + result = f(cx).fuse() => result.map(Some), + } + }) + } +} + /// Singleton global maintaining the user's participation in a room across workspaces. pub struct ActiveCall { room: Option<(ModelHandle, Vec)>, @@ -49,6 +74,7 @@ pub struct ActiveCall { ), client: Arc, user_store: ModelHandle, + _join_debouncer: OneAtATime, _subscriptions: Vec, } @@ -69,6 +95,7 @@ impl ActiveCall { pending_invites: Default::default(), incoming_call: watch::channel(), + _join_debouncer: OneAtATime { cancel: None }, _subscriptions: vec![ client.add_request_handler(cx.handle(), Self::handle_incoming_call), client.add_message_handler(cx.handle(), Self::handle_call_canceled), @@ -259,11 +286,16 @@ impl ActiveCall { return Task::ready(Err(anyhow!("no incoming call"))); }; - let join = Room::join(&call, self.client.clone(), self.user_store.clone(), cx); + let room_id = call.room_id.clone(); + let client = self.client.clone(); + let user_store = self.user_store.clone(); + let join = self + ._join_debouncer + .spawn(cx, move |cx| Room::join(room_id, client, user_store, cx)); cx.spawn(|this, mut cx| async move { let room = join.await?; - this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx)) + this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx)) .await?; this.update(&mut cx, |this, cx| { this.report_call_event("accept incoming", cx) @@ -290,20 +322,24 @@ impl ActiveCall { &mut self, channel_id: u64, cx: &mut ModelContext, - ) -> Task>> { + ) -> Task>>> { if let Some(room) = self.room().cloned() { if room.read(cx).channel_id() == Some(channel_id) { - return Task::ready(Ok(room)); + return Task::ready(Ok(Some(room))); } else { room.update(cx, |room, cx| room.clear_state(cx)); } } - let join = Room::join_channel(channel_id, self.client.clone(), self.user_store.clone(), cx); + let client = self.client.clone(); + let user_store = self.user_store.clone(); + let join = self._join_debouncer.spawn(cx, move |cx| async move { + Room::join_channel(channel_id, client, user_store, cx).await + }); - cx.spawn(|this, mut cx| async move { + cx.spawn(move |this, mut cx| async move { let room = join.await?; - this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx)) + this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx)) .await?; this.update(&mut cx, |this, cx| { this.report_call_event("join channel", cx) @@ -457,3 +493,40 @@ pub fn report_call_event_for_channel( }; telemetry.report_clickhouse_event(event, telemetry_settings); } + +#[cfg(test)] +mod test { + use gpui::TestAppContext; + + use crate::OneAtATime; + + #[gpui::test] + async fn test_one_at_a_time(cx: &mut TestAppContext) { + let mut one_at_a_time = OneAtATime { cancel: None }; + + assert_eq!( + cx.update(|cx| one_at_a_time.spawn(cx, |_| async { Ok(1) })) + .await + .unwrap(), + Some(1) + ); + + let (a, b) = cx.update(|cx| { + ( + one_at_a_time.spawn(cx, |_| async { + assert!(false); + Ok(2) + }), + one_at_a_time.spawn(cx, |_| async { Ok(3) }), + ) + }); + + assert_eq!(a.await.unwrap(), None); + assert_eq!(b.await.unwrap(), Some(3)); + + let promise = cx.update(|cx| one_at_a_time.spawn(cx, |_| async { Ok(4) })); + drop(one_at_a_time); + + assert_eq!(promise.await.unwrap(), None); + } +} diff --git a/crates/call/src/room.rs b/crates/call/src/room.rs index 4e52f57f60b36c0f83400279c2feb45ee9e2b713..4cd815e856ee44c75bf7b8840d678863cb5c4cc3 100644 --- a/crates/call/src/room.rs +++ b/crates/call/src/room.rs @@ -1,7 +1,6 @@ use crate::{ call_settings::CallSettings, participant::{LocalParticipant, ParticipantLocation, RemoteParticipant, RemoteVideoTrack}, - IncomingCall, }; use anyhow::{anyhow, Result}; use audio::{Audio, Sound}; @@ -284,37 +283,32 @@ impl Room { }) } - pub(crate) fn join_channel( + pub(crate) async fn join_channel( channel_id: u64, client: Arc, user_store: ModelHandle, - cx: &mut AppContext, - ) -> Task>> { - cx.spawn(|cx| async move { - Self::from_join_response( - client.request(proto::JoinChannel { channel_id }).await?, - client, - user_store, - cx, - ) - }) + cx: AsyncAppContext, + ) -> Result> { + Self::from_join_response( + client.request(proto::JoinChannel { channel_id }).await?, + client, + user_store, + cx, + ) } - pub(crate) fn join( - call: &IncomingCall, + pub(crate) async fn join( + room_id: u64, client: Arc, user_store: ModelHandle, - cx: &mut AppContext, - ) -> Task>> { - let id = call.room_id; - cx.spawn(|cx| async move { - Self::from_join_response( - client.request(proto::JoinRoom { id }).await?, - client, - user_store, - cx, - ) - }) + cx: AsyncAppContext, + ) -> Result> { + Self::from_join_response( + client.request(proto::JoinRoom { id: room_id }).await?, + client, + user_store, + cx, + ) } pub fn mute_on_join(cx: &AppContext) -> bool { diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index e879b981ef65f55ab221ddd888b84d025dbd1ef7..f0b2c7b0765d61b9a5173f1e97fa15e04faf20e5 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -4238,6 +4238,10 @@ async fn join_channel_internal( }) .await?; + let Some(room) = room else { + return anyhow::Ok(true); + }; + room.update(cx, |room, _| room.room_update_completed()) .await;