Leave Zed room when LiveKit room disconnects

Antonio Scandurra created

Change summary

crates/call/src/room.rs                                                        | 25 
crates/collab/src/integration_tests.rs                                         | 70 
crates/live_kit_client/Cargo.toml                                              |  2 
crates/live_kit_client/LiveKitBridge/Sources/LiveKitBridge/LiveKitBridge.swift | 29 
crates/live_kit_client/src/prod.rs                                             | 46 
crates/live_kit_client/src/test.rs                                             | 63 
crates/live_kit_server/src/api.rs                                              |  3 
7 files changed, 200 insertions(+), 38 deletions(-)

Detailed changes

crates/call/src/room.rs 🔗

@@ -8,6 +8,7 @@ use collections::{BTreeMap, HashSet};
 use futures::StreamExt;
 use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
 use live_kit_client::{LocalTrackPublication, LocalVideoTrack, RemoteVideoTrackUpdate};
+use postage::stream::Stream;
 use project::Project;
 use std::{mem, os::unix::prelude::OsStrExt, sync::Arc};
 use util::{post_inc, ResultExt};
@@ -80,8 +81,26 @@ impl Room {
 
         let live_kit_room = if let Some(connection_info) = live_kit_connection_info {
             let room = live_kit_client::Room::new();
-            let mut track_changes = room.remote_video_track_updates();
+            let mut status = room.status();
+            // Consume the initial status of the room.
+            let _ = status.try_recv();
             let _maintain_room = cx.spawn_weak(|this, mut cx| async move {
+                while let Some(status) = status.next().await {
+                    let this = if let Some(this) = this.upgrade(&cx) {
+                        this
+                    } else {
+                        break;
+                    };
+
+                    if status == live_kit_client::ConnectionState::Disconnected {
+                        this.update(&mut cx, |this, cx| this.leave(cx).log_err());
+                        break;
+                    }
+                }
+            });
+
+            let mut track_changes = room.remote_video_track_updates();
+            let _maintain_tracks = cx.spawn_weak(|this, mut cx| async move {
                 while let Some(track_change) = track_changes.next().await {
                     let this = if let Some(this) = this.upgrade(&cx) {
                         this
@@ -94,14 +113,17 @@ impl Room {
                     });
                 }
             });
+
             cx.foreground()
                 .spawn(room.connect(&connection_info.server_url, &connection_info.token))
                 .detach_and_log_err(cx);
+
             Some(LiveKitRoom {
                 room,
                 screen_track: ScreenTrack::None,
                 next_publish_id: 0,
                 _maintain_room,
+                _maintain_tracks,
             })
         } else {
             None
@@ -725,6 +747,7 @@ struct LiveKitRoom {
     screen_track: ScreenTrack,
     next_publish_id: usize,
     _maintain_room: Task<()>,
+    _maintain_tracks: Task<()>,
 }
 
 pub enum ScreenTrack {

crates/collab/src/integration_tests.rs 🔗

@@ -253,12 +253,13 @@ async fn test_basic_calls(
         }
     );
 
-    // User B leaves the room.
-    active_call_b.update(cx_b, |call, cx| {
-        call.hang_up(cx).unwrap();
-        assert!(call.room().is_none());
-    });
-    deterministic.run_until_parked();
+    // User B gets disconnected from the LiveKit server, which causes them
+    // to automatically leave the room.
+    server
+        .test_live_kit_server
+        .disconnect_client(client_b.peer_id().unwrap().to_string())
+        .await;
+    active_call_b.update(cx_b, |call, _| assert!(call.room().is_none()));
     assert_eq!(
         room_participants(&room_a, cx_a),
         RoomParticipants {
@@ -452,6 +453,63 @@ async fn test_leaving_room_on_disconnection(
             pending: Default::default()
         }
     );
+
+    // Call user B again from client A.
+    active_call_a
+        .update(cx_a, |call, cx| {
+            call.invite(client_b.user_id().unwrap(), None, cx)
+        })
+        .await
+        .unwrap();
+    let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone());
+
+    // User B receives the call and joins the room.
+    let mut incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming());
+    incoming_call_b.next().await.unwrap().unwrap();
+    active_call_b
+        .update(cx_b, |call, cx| call.accept_incoming(cx))
+        .await
+        .unwrap();
+    let room_b = active_call_b.read_with(cx_b, |call, _| call.room().unwrap().clone());
+    deterministic.run_until_parked();
+    assert_eq!(
+        room_participants(&room_a, cx_a),
+        RoomParticipants {
+            remote: vec!["user_b".to_string()],
+            pending: Default::default()
+        }
+    );
+    assert_eq!(
+        room_participants(&room_b, cx_b),
+        RoomParticipants {
+            remote: vec!["user_a".to_string()],
+            pending: Default::default()
+        }
+    );
+
+    // User B gets disconnected from the LiveKit server, which causes it
+    // to automatically leave the room.
+    server
+        .test_live_kit_server
+        .disconnect_client(client_b.peer_id().unwrap().to_string())
+        .await;
+    deterministic.run_until_parked();
+    active_call_a.update(cx_a, |call, _| assert!(call.room().is_none()));
+    active_call_b.update(cx_b, |call, _| assert!(call.room().is_none()));
+    assert_eq!(
+        room_participants(&room_a, cx_a),
+        RoomParticipants {
+            remote: Default::default(),
+            pending: Default::default()
+        }
+    );
+    assert_eq!(
+        room_participants(&room_b, cx_b),
+        RoomParticipants {
+            remote: Default::default(),
+            pending: Default::default()
+        }
+    );
 }
 
 #[gpui::test(iterations = 10)]

crates/live_kit_client/Cargo.toml 🔗

@@ -34,6 +34,7 @@ core-graphics = "0.22.3"
 futures = "0.3"
 log = { version = "0.4.16", features = ["kv_unstable_serde"] }
 parking_lot = "0.11.1"
+postage = { version = "0.4.1", features = ["futures-traits"] }
 
 async-trait = { version = "0.1", optional = true }
 lazy_static = { version = "1.4", optional = true }
@@ -60,7 +61,6 @@ jwt = "0.16"
 lazy_static = "1.4"
 objc = "0.2"
 parking_lot = "0.11.1"
-postage = { version = "0.4.1", features = ["futures-traits"] }
 serde = { version = "1.0", features = ["derive", "rc"] }
 sha2 = "0.10"
 simplelog = "0.9"

crates/live_kit_client/LiveKitBridge/Sources/LiveKitBridge/LiveKitBridge.swift 🔗

@@ -5,15 +5,28 @@ import ScreenCaptureKit
 
 class LKRoomDelegate: RoomDelegate {
     var data: UnsafeRawPointer
+    var onDidDisconnect: @convention(c) (UnsafeRawPointer) -> Void
     var onDidSubscribeToRemoteVideoTrack: @convention(c) (UnsafeRawPointer, CFString, CFString, UnsafeRawPointer) -> Void
     var onDidUnsubscribeFromRemoteVideoTrack: @convention(c) (UnsafeRawPointer, CFString, CFString) -> Void
     
-    init(data: UnsafeRawPointer, onDidSubscribeToRemoteVideoTrack: @escaping @convention(c) (UnsafeRawPointer, CFString, CFString, UnsafeRawPointer) -> Void, onDidUnsubscribeFromRemoteVideoTrack: @escaping @convention(c) (UnsafeRawPointer, CFString, CFString) -> Void) {
+    init(
+        data: UnsafeRawPointer,
+        onDidDisconnect: @escaping @convention(c) (UnsafeRawPointer) -> Void,
+        onDidSubscribeToRemoteVideoTrack: @escaping @convention(c) (UnsafeRawPointer, CFString, CFString, UnsafeRawPointer) -> Void,
+        onDidUnsubscribeFromRemoteVideoTrack: @escaping @convention(c) (UnsafeRawPointer, CFString, CFString) -> Void)
+    {
         self.data = data
+        self.onDidDisconnect = onDidDisconnect
         self.onDidSubscribeToRemoteVideoTrack = onDidSubscribeToRemoteVideoTrack
         self.onDidUnsubscribeFromRemoteVideoTrack = onDidUnsubscribeFromRemoteVideoTrack
     }
 
+    func room(_ room: Room, didUpdate connectionState: ConnectionState, oldValue: ConnectionState) {
+        if connectionState.isDisconnected {
+            self.onDidDisconnect(self.data)
+        }
+    }
+
     func room(_ room: Room, participant: RemoteParticipant, didSubscribe publication: RemoteTrackPublication, track: Track) {
         if track.kind == .video {
             self.onDidSubscribeToRemoteVideoTrack(self.data, participant.identity as CFString, track.sid! as CFString, Unmanaged.passUnretained(track).toOpaque())
@@ -62,8 +75,18 @@ class LKVideoRenderer: NSObject, VideoRenderer {
 }
 
 @_cdecl("LKRoomDelegateCreate")
-public func LKRoomDelegateCreate(data: UnsafeRawPointer, onDidSubscribeToRemoteVideoTrack: @escaping @convention(c) (UnsafeRawPointer, CFString, CFString, UnsafeRawPointer) -> Void, onDidUnsubscribeFromRemoteVideoTrack: @escaping @convention(c) (UnsafeRawPointer, CFString, CFString) -> Void) -> UnsafeMutableRawPointer {
-    let delegate = LKRoomDelegate(data: data, onDidSubscribeToRemoteVideoTrack: onDidSubscribeToRemoteVideoTrack, onDidUnsubscribeFromRemoteVideoTrack: onDidUnsubscribeFromRemoteVideoTrack)
+public func LKRoomDelegateCreate(
+    data: UnsafeRawPointer,
+    onDidDisconnect: @escaping @convention(c) (UnsafeRawPointer) -> Void,
+    onDidSubscribeToRemoteVideoTrack: @escaping @convention(c) (UnsafeRawPointer, CFString, CFString, UnsafeRawPointer) -> Void,
+    onDidUnsubscribeFromRemoteVideoTrack: @escaping @convention(c) (UnsafeRawPointer, CFString, CFString) -> Void
+) -> UnsafeMutableRawPointer {
+    let delegate = LKRoomDelegate(
+        data: data,
+        onDidDisconnect: onDidDisconnect,
+        onDidSubscribeToRemoteVideoTrack: onDidSubscribeToRemoteVideoTrack,
+        onDidUnsubscribeFromRemoteVideoTrack: onDidUnsubscribeFromRemoteVideoTrack
+    )
     return Unmanaged.passRetained(delegate).toOpaque()
 }
 

crates/live_kit_client/src/prod.rs 🔗

@@ -11,6 +11,7 @@ use futures::{
 pub use media::core_video::CVImageBuffer;
 use media::core_video::CVImageBufferRef;
 use parking_lot::Mutex;
+use postage::watch;
 use std::{
     ffi::c_void,
     sync::{Arc, Weak},
@@ -19,6 +20,7 @@ use std::{
 extern "C" {
     fn LKRoomDelegateCreate(
         callback_data: *mut c_void,
+        on_did_disconnect: extern "C" fn(callback_data: *mut c_void),
         on_did_subscribe_to_remote_video_track: extern "C" fn(
             callback_data: *mut c_void,
             publisher_id: CFStringRef,
@@ -75,8 +77,18 @@ extern "C" {
 
 pub type Sid = String;
 
+#[derive(Clone, Eq, PartialEq)]
+pub enum ConnectionState {
+    Disconnected,
+    Connected { url: String, token: String },
+}
+
 pub struct Room {
     native_room: *const c_void,
+    connection: Mutex<(
+        watch::Sender<ConnectionState>,
+        watch::Receiver<ConnectionState>,
+    )>,
     remote_video_track_subscribers: Mutex<Vec<mpsc::UnboundedSender<RemoteVideoTrackUpdate>>>,
     _delegate: RoomDelegate,
 }
@@ -87,13 +99,18 @@ impl Room {
             let delegate = RoomDelegate::new(weak_room.clone());
             Self {
                 native_room: unsafe { LKRoomCreate(delegate.native_delegate) },
+                connection: Mutex::new(watch::channel_with(ConnectionState::Disconnected)),
                 remote_video_track_subscribers: Default::default(),
                 _delegate: delegate,
             }
         })
     }
 
-    pub fn connect(&self, url: &str, token: &str) -> impl Future<Output = Result<()>> {
+    pub fn status(&self) -> watch::Receiver<ConnectionState> {
+        self.connection.lock().1.clone()
+    }
+
+    pub fn connect(self: &Arc<Self>, url: &str, token: &str) -> impl Future<Output = Result<()>> {
         let url = CFString::new(url);
         let token = CFString::new(token);
         let (did_connect, tx, rx) = Self::build_done_callback();
@@ -107,7 +124,23 @@ impl Room {
             )
         }
 
-        async { rx.await.unwrap().context("error connecting to room") }
+        let this = self.clone();
+        let url = url.to_string();
+        let token = token.to_string();
+        async move {
+            match rx.await.unwrap().context("error connecting to room") {
+                Ok(()) => {
+                    *this.connection.lock().0.borrow_mut() =
+                        ConnectionState::Connected { url, token };
+                    Ok(())
+                }
+                Err(err) => Err(err),
+            }
+        }
+    }
+
+    fn did_disconnect(&self) {
+        *self.connection.lock().0.borrow_mut() = ConnectionState::Disconnected;
     }
 
     pub fn display_sources(self: &Arc<Self>) -> impl Future<Output = Result<Vec<MacOSDisplay>>> {
@@ -265,6 +298,7 @@ impl RoomDelegate {
         let native_delegate = unsafe {
             LKRoomDelegateCreate(
                 weak_room as *mut c_void,
+                Self::on_did_disconnect,
                 Self::on_did_subscribe_to_remote_video_track,
                 Self::on_did_unsubscribe_from_remote_video_track,
             )
@@ -275,6 +309,14 @@ impl RoomDelegate {
         }
     }
 
+    extern "C" fn on_did_disconnect(room: *mut c_void) {
+        let room = unsafe { Weak::from_raw(room as *mut Room) };
+        if let Some(room) = room.upgrade() {
+            room.did_disconnect();
+        }
+        let _ = Weak::into_raw(room);
+    }
+
     extern "C" fn on_did_subscribe_to_remote_video_track(
         room: *mut c_void,
         publisher_id: CFStringRef,

crates/live_kit_client/src/test.rs 🔗

@@ -7,7 +7,8 @@ use lazy_static::lazy_static;
 use live_kit_server::token;
 use media::core_video::CVImageBuffer;
 use parking_lot::Mutex;
-use std::{future::Future, sync::Arc};
+use postage::watch;
+use std::{future::Future, mem, sync::Arc};
 
 lazy_static! {
     static ref SERVERS: Mutex<HashMap<String, Arc<TestServer>>> = Default::default();
@@ -145,6 +146,16 @@ impl TestServer {
         Ok(())
     }
 
+    pub async fn disconnect_client(&self, client_identity: String) {
+        self.background.simulate_random_delay().await;
+        let mut server_rooms = self.rooms.lock();
+        for room in server_rooms.values_mut() {
+            if let Some(room) = room.client_rooms.remove(&client_identity) {
+                *room.0.lock().connection.0.borrow_mut() = ConnectionState::Disconnected;
+            }
+        }
+    }
+
     async fn publish_video_track(&self, token: String, local_track: LocalVideoTrack) -> Result<()> {
         self.background.simulate_random_delay().await;
         let claims = live_kit_server::token::validate(&token, &self.secret_key)?;
@@ -227,7 +238,10 @@ impl live_kit_server::api::Client for TestApiClient {
 pub type Sid = String;
 
 struct RoomState {
-    connection: Option<ConnectionState>,
+    connection: (
+        watch::Sender<ConnectionState>,
+        watch::Receiver<ConnectionState>,
+    ),
     display_sources: Vec<MacOSDisplay>,
     video_track_updates: (
         async_broadcast::Sender<RemoteVideoTrackUpdate>,
@@ -235,9 +249,10 @@ struct RoomState {
     ),
 }
 
-struct ConnectionState {
-    url: String,
-    token: String,
+#[derive(Clone, Eq, PartialEq)]
+pub enum ConnectionState {
+    Disconnected,
+    Connected { url: String, token: String },
 }
 
 pub struct Room(Mutex<RoomState>);
@@ -245,12 +260,16 @@ pub struct Room(Mutex<RoomState>);
 impl Room {
     pub fn new() -> Arc<Self> {
         Arc::new(Self(Mutex::new(RoomState {
-            connection: None,
+            connection: watch::channel_with(ConnectionState::Disconnected),
             display_sources: Default::default(),
             video_track_updates: async_broadcast::broadcast(128),
         })))
     }
 
+    pub fn status(&self) -> watch::Receiver<ConnectionState> {
+        self.0.lock().connection.1.clone()
+    }
+
     pub fn connect(self: &Arc<Self>, url: &str, token: &str) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         let url = url.to_string();
@@ -258,7 +277,7 @@ impl Room {
         async move {
             let server = TestServer::get(&url)?;
             server.join_room(token.clone(), this.clone()).await?;
-            this.0.lock().connection = Some(ConnectionState { url, token });
+            *this.0.lock().connection.0.borrow_mut() = ConnectionState::Connected { url, token };
             Ok(())
         }
     }
@@ -301,32 +320,30 @@ impl Room {
     }
 
     fn test_server(&self) -> Arc<TestServer> {
-        let this = self.0.lock();
-        let connection = this
-            .connection
-            .as_ref()
-            .expect("must be connected to call this method");
-        TestServer::get(&connection.url).unwrap()
+        match self.0.lock().connection.1.borrow().clone() {
+            ConnectionState::Disconnected => panic!("must be connected to call this method"),
+            ConnectionState::Connected { url, .. } => TestServer::get(&url).unwrap(),
+        }
     }
 
     fn token(&self) -> String {
-        self.0
-            .lock()
-            .connection
-            .as_ref()
-            .expect("must be connected to call this method")
-            .token
-            .clone()
+        match self.0.lock().connection.1.borrow().clone() {
+            ConnectionState::Disconnected => panic!("must be connected to call this method"),
+            ConnectionState::Connected { token, .. } => token,
+        }
     }
 }
 
 impl Drop for Room {
     fn drop(&mut self) {
-        if let Some(connection) = self.0.lock().connection.take() {
-            if let Ok(server) = TestServer::get(&connection.token) {
+        if let ConnectionState::Connected { token, .. } = mem::replace(
+            &mut *self.0.lock().connection.0.borrow_mut(),
+            ConnectionState::Disconnected,
+        ) {
+            if let Ok(server) = TestServer::get(&token) {
                 let background = server.background.clone();
                 background
-                    .spawn(async move { server.leave_room(connection.token).await.unwrap() })
+                    .spawn(async move { server.leave_room(token).await.unwrap() })
                     .detach();
             }
         }

crates/live_kit_server/src/api.rs 🔗

@@ -86,7 +86,7 @@ impl Client for LiveKitClient {
     }
 
     async fn create_room(&self, name: String) -> Result<()> {
-        let x: proto::Room = self
+        let _: proto::Room = self
             .request(
                 "twirp/livekit.RoomService/CreateRoom",
                 token::VideoGrant {
@@ -99,7 +99,6 @@ impl Client for LiveKitClient {
                 },
             )
             .await?;
-        dbg!(x);
         Ok(())
     }