fix rejoin after quit (#10100)

Conrad Irwin created

Release Notes:

- collab: Fixed rejoining channels quickly after a restart

Change summary

crates/collab/src/db/queries/rooms.rs        | 88 +++++++++++++--------
crates/collab/src/rpc.rs                     | 23 ++++-
crates/collab/src/tests/following_tests.rs   |  2 
crates/collab/src/tests/integration_tests.rs | 20 ++++
crates/collab/src/tests/test_server.rs       |  5 
crates/gpui/src/executor.rs                  | 14 +++
crates/gpui/src/platform/test/dispatcher.rs  | 10 ++
crates/gpui/src/platform/test/platform.rs    |  5 +
crates/gpui/src/platform/test/window.rs      |  6 
9 files changed, 124 insertions(+), 49 deletions(-)

Detailed changes

crates/collab/src/db/queries/rooms.rs 🔗

@@ -349,6 +349,17 @@ impl Database {
         .await
     }
 
+    pub async fn stale_room_connection(&self, user_id: UserId) -> Result<Option<ConnectionId>> {
+        self.transaction(|tx| async move {
+            let participant = room_participant::Entity::find()
+                .filter(room_participant::Column::UserId.eq(user_id))
+                .one(&*tx)
+                .await?;
+            Ok(participant.and_then(|p| p.answering_connection()))
+        })
+        .await
+    }
+
     async fn get_next_participant_index_internal(
         &self,
         room_id: RoomId,
@@ -403,39 +414,50 @@ impl Database {
             .get_next_participant_index_internal(room_id, tx)
             .await?;
 
-        room_participant::Entity::insert_many([room_participant::ActiveModel {
-            room_id: ActiveValue::set(room_id),
-            user_id: ActiveValue::set(user_id),
-            answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
-            answering_connection_server_id: ActiveValue::set(Some(ServerId(
-                connection.owner_id as i32,
-            ))),
-            answering_connection_lost: ActiveValue::set(false),
-            calling_user_id: ActiveValue::set(user_id),
-            calling_connection_id: ActiveValue::set(connection.id as i32),
-            calling_connection_server_id: ActiveValue::set(Some(ServerId(
-                connection.owner_id as i32,
-            ))),
-            participant_index: ActiveValue::Set(Some(participant_index)),
-            role: ActiveValue::set(Some(role)),
-            id: ActiveValue::NotSet,
-            location_kind: ActiveValue::NotSet,
-            location_project_id: ActiveValue::NotSet,
-            initial_project_id: ActiveValue::NotSet,
-        }])
-        .on_conflict(
-            OnConflict::columns([room_participant::Column::UserId])
-                .update_columns([
-                    room_participant::Column::AnsweringConnectionId,
-                    room_participant::Column::AnsweringConnectionServerId,
-                    room_participant::Column::AnsweringConnectionLost,
-                    room_participant::Column::ParticipantIndex,
-                    room_participant::Column::Role,
-                ])
-                .to_owned(),
-        )
-        .exec(tx)
-        .await?;
+        // If someone has been invited into the room, accept the invite instead of inserting
+        let result = room_participant::Entity::update_many()
+            .filter(
+                Condition::all()
+                    .add(room_participant::Column::RoomId.eq(room_id))
+                    .add(room_participant::Column::UserId.eq(user_id))
+                    .add(room_participant::Column::AnsweringConnectionId.is_null()),
+            )
+            .set(room_participant::ActiveModel {
+                participant_index: ActiveValue::Set(Some(participant_index)),
+                answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
+                answering_connection_server_id: ActiveValue::set(Some(ServerId(
+                    connection.owner_id as i32,
+                ))),
+                answering_connection_lost: ActiveValue::set(false),
+                ..Default::default()
+            })
+            .exec(tx)
+            .await?;
+
+        if result.rows_affected == 0 {
+            room_participant::Entity::insert(room_participant::ActiveModel {
+                room_id: ActiveValue::set(room_id),
+                user_id: ActiveValue::set(user_id),
+                answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
+                answering_connection_server_id: ActiveValue::set(Some(ServerId(
+                    connection.owner_id as i32,
+                ))),
+                answering_connection_lost: ActiveValue::set(false),
+                calling_user_id: ActiveValue::set(user_id),
+                calling_connection_id: ActiveValue::set(connection.id as i32),
+                calling_connection_server_id: ActiveValue::set(Some(ServerId(
+                    connection.owner_id as i32,
+                ))),
+                participant_index: ActiveValue::Set(Some(participant_index)),
+                role: ActiveValue::set(Some(role)),
+                id: ActiveValue::NotSet,
+                location_kind: ActiveValue::NotSet,
+                location_project_id: ActiveValue::NotSet,
+                initial_project_id: ActiveValue::NotSet,
+            })
+            .exec(tx)
+            .await?;
+        }
 
         let (channel, room) = self.get_channel_room(room_id, &tx).await?;
         let channel = channel.ok_or_else(|| anyhow!("no channel for room"))?;

crates/collab/src/rpc.rs 🔗

@@ -1203,7 +1203,7 @@ async fn connection_lost(
         _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
             if let Some(session) = session.for_user() {
                 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
-                leave_room_for_session(&session).await.trace_err();
+                leave_room_for_session(&session, session.connection_id).await.trace_err();
                 leave_channel_buffers_for_session(&session)
                     .await
                     .trace_err();
@@ -1539,7 +1539,7 @@ async fn leave_room(
     response: Response<proto::LeaveRoom>,
     session: UserSession,
 ) -> Result<()> {
-    leave_room_for_session(&session).await?;
+    leave_room_for_session(&session, session.connection_id).await?;
     response.send(proto::Ack {})?;
     Ok(())
 }
@@ -3023,8 +3023,19 @@ async fn join_channel_internal(
     session: UserSession,
 ) -> Result<()> {
     let joined_room = {
-        leave_room_for_session(&session).await?;
-        let db = session.db().await;
+        let mut db = session.db().await;
+        // If zed quits without leaving the room, and the user re-opens zed before the
+        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
+        // room they were in.
+        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
+            tracing::info!(
+                stale_connection_id = %connection,
+                "cleaning up stale connection",
+            );
+            drop(db);
+            leave_room_for_session(&session, connection).await?;
+            db = session.db().await;
+        }
 
         let (joined_room, membership_updated, role) = db
             .join_channel(channel_id, session.user_id(), session.connection_id)
@@ -4199,7 +4210,7 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()>
     Ok(())
 }
 
-async fn leave_room_for_session(session: &UserSession) -> Result<()> {
+async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
     let mut contacts_to_update = HashSet::default();
 
     let room_id;
@@ -4209,7 +4220,7 @@ async fn leave_room_for_session(session: &UserSession) -> Result<()> {
     let room;
     let channel;
 
-    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
+    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
         contacts_to_update.insert(session.user_id());
 
         for project in left_room.left_projects.values() {

crates/collab/src/tests/following_tests.rs 🔗

@@ -2007,7 +2007,7 @@ async fn test_following_to_channel_notes_without_a_shared_project(
     });
 }
 
-async fn join_channel(
+pub(crate) async fn join_channel(
     channel_id: ChannelId,
     client: &TestClient,
     cx: &mut TestAppContext,

crates/collab/src/tests/integration_tests.rs 🔗

@@ -1,6 +1,9 @@
 use crate::{
     rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
-    tests::{channel_id, room_participants, rust_lang, RoomParticipants, TestClient, TestServer},
+    tests::{
+        channel_id, following_tests::join_channel, room_participants, rust_lang, RoomParticipants,
+        TestClient, TestServer,
+    },
 };
 use call::{room, ActiveCall, ParticipantLocation, Room};
 use client::{User, RECEIVE_TIMEOUT};
@@ -5914,7 +5917,7 @@ async fn test_right_click_menu_behind_collab_panel(cx: &mut TestAppContext) {
 
 #[gpui::test]
 async fn test_cmd_k_left(cx: &mut TestAppContext) {
-    let client = TestServer::start1(cx).await;
+    let (_, client) = TestServer::start1(cx).await;
     let (workspace, cx) = client.build_test_workspace(cx).await;
 
     cx.simulate_keystrokes("cmd-n");
@@ -5934,3 +5937,16 @@ async fn test_cmd_k_left(cx: &mut TestAppContext) {
         assert!(workspace.items(cx).collect::<Vec<_>>().len() == 2);
     });
 }
+
+#[gpui::test]
+async fn test_join_after_restart(cx1: &mut TestAppContext, cx2: &mut TestAppContext) {
+    let (mut server, client) = TestServer::start1(cx1).await;
+    let channel1 = server.make_public_channel("channel1", &client, cx1).await;
+    let channel2 = server.make_public_channel("channel2", &client, cx1).await;
+
+    join_channel(channel1, &client, cx1).await.unwrap();
+    drop(client);
+
+    let client2 = server.create_client(cx2, "user_a").await;
+    join_channel(channel2, &client2, cx2).await.unwrap();
+}

crates/collab/src/tests/test_server.rs 🔗

@@ -135,9 +135,10 @@ impl TestServer {
         (server, client_a, client_b, channel_id)
     }
 
-    pub async fn start1(cx: &mut TestAppContext) -> TestClient {
+    pub async fn start1(cx: &mut TestAppContext) -> (TestServer, TestClient) {
         let mut server = Self::start(cx.executor().clone()).await;
-        server.create_client(cx, "user_a").await
+        let client = server.create_client(cx, "user_a").await;
+        (server, client)
     }
 
     pub async fn reset(&self) {

crates/gpui/src/executor.rs 🔗

@@ -219,11 +219,17 @@ impl BackgroundExecutor {
                         if let Some(test) = self.dispatcher.as_test() {
                             if !test.parking_allowed() {
                                 let mut backtrace_message = String::new();
+                                let mut waiting_message = String::new();
                                 if let Some(backtrace) = test.waiting_backtrace() {
                                     backtrace_message =
                                         format!("\nbacktrace of waiting future:\n{:?}", backtrace);
                                 }
-                                panic!("parked with nothing left to run\n{:?}", backtrace_message)
+                                if let Some(waiting_hint) = test.waiting_hint() {
+                                    waiting_message = format!("\n  waiting on: {}\n", waiting_hint);
+                                }
+                                panic!(
+                                    "parked with nothing left to run{waiting_message}{backtrace_message}",
+                                )
                             }
                         }
 
@@ -354,6 +360,12 @@ impl BackgroundExecutor {
         self.dispatcher.as_test().unwrap().forbid_parking();
     }
 
+    /// adds detail to the "parked with nothing let to run" message.
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn set_waiting_hint(&self, msg: Option<String>) {
+        self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
+    }
+
     /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
     #[cfg(any(test, feature = "test-support"))]
     pub fn rng(&self) -> StdRng {

crates/gpui/src/platform/test/dispatcher.rs 🔗

@@ -36,6 +36,7 @@ struct TestDispatcherState {
     is_main_thread: bool,
     next_id: TestDispatcherId,
     allow_parking: bool,
+    waiting_hint: Option<String>,
     waiting_backtrace: Option<Backtrace>,
     deprioritized_task_labels: HashSet<TaskLabel>,
     block_on_ticks: RangeInclusive<usize>,
@@ -54,6 +55,7 @@ impl TestDispatcher {
             is_main_thread: true,
             next_id: TestDispatcherId(1),
             allow_parking: false,
+            waiting_hint: None,
             waiting_backtrace: None,
             deprioritized_task_labels: Default::default(),
             block_on_ticks: 0..=1000,
@@ -132,6 +134,14 @@ impl TestDispatcher {
         self.state.lock().allow_parking = false
     }
 
+    pub fn set_waiting_hint(&self, msg: Option<String>) {
+        self.state.lock().waiting_hint = msg
+    }
+
+    pub fn waiting_hint(&self) -> Option<String> {
+        self.state.lock().waiting_hint.clone()
+    }
+
     pub fn start_waiting(&self) {
         self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
     }

crates/gpui/src/platform/test/platform.rs 🔗

@@ -69,6 +69,7 @@ impl TestPlatform {
             .multiple_choice
             .pop_front()
             .expect("no pending multiple choice prompt");
+        self.background_executor().set_waiting_hint(None);
         tx.send(response_ix).ok();
     }
 
@@ -76,8 +77,10 @@ impl TestPlatform {
         !self.prompts.borrow().multiple_choice.is_empty()
     }
 
-    pub(crate) fn prompt(&self) -> oneshot::Receiver<usize> {
+    pub(crate) fn prompt(&self, msg: &str, detail: Option<&str>) -> oneshot::Receiver<usize> {
         let (tx, rx) = oneshot::channel();
+        self.background_executor()
+            .set_waiting_hint(Some(format!("PROMPT: {:?} {:?}", msg, detail)));
         self.prompts.borrow_mut().multiple_choice.push_back(tx);
         rx
     }

crates/gpui/src/platform/test/window.rs 🔗

@@ -159,8 +159,8 @@ impl PlatformWindow for TestWindow {
     fn prompt(
         &self,
         _level: crate::PromptLevel,
-        _msg: &str,
-        _detail: Option<&str>,
+        msg: &str,
+        detail: Option<&str>,
         _answers: &[&str],
     ) -> Option<futures::channel::oneshot::Receiver<usize>> {
         Some(
@@ -169,7 +169,7 @@ impl PlatformWindow for TestWindow {
                 .platform
                 .upgrade()
                 .expect("platform dropped")
-                .prompt(),
+                .prompt(msg, detail),
         )
     }