Add the ability to jump between channels while in a channel

Mikayla Maki created

Change summary

crates/call/src/call.rs                  |  6 +++
crates/client/src/client.rs              |  6 ++
crates/collab/src/db.rs                  | 29 ++++++++++++++++
crates/collab/src/rpc.rs                 | 25 ++++++++++++-
crates/collab/src/tests/channel_tests.rs | 47 ++++++++++++++++++++++++++
5 files changed, 110 insertions(+), 3 deletions(-)

Detailed changes

crates/call/src/call.rs 🔗

@@ -279,15 +279,21 @@ impl ActiveCall {
         channel_id: u64,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<()>> {
+        let leave_room;
         if let Some(room) = self.room().cloned() {
             if room.read(cx).channel_id() == Some(channel_id) {
                 return Task::ready(Ok(()));
+            } else {
+                leave_room = room.update(cx, |room, cx| room.leave(cx));
             }
+        } else {
+            leave_room = Task::ready(Ok(()));
         }
 
         let join = Room::join_channel(channel_id, self.client.clone(), self.user_store.clone(), cx);
 
         cx.spawn(|this, mut cx| async move {
+            leave_room.await?;
             let room = join.await?;
             this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx))
                 .await?;

crates/client/src/client.rs 🔗

@@ -540,6 +540,7 @@ impl Client {
         }
     }
 
+    #[track_caller]
     pub fn add_message_handler<M, E, H, F>(
         self: &Arc<Self>,
         model: ModelHandle<E>,
@@ -575,8 +576,11 @@ impl Client {
             }),
         );
         if prev_handler.is_some() {
+            let location = std::panic::Location::caller();
             panic!(
-                "registered handler for the same message {} twice",
+                "{}:{} registered handler for the same message {} twice",
+                location.file(),
+                location.line(),
                 std::any::type_name::<M>()
             );
         }

crates/collab/src/db.rs 🔗

@@ -1342,6 +1342,35 @@ impl Database {
         .await
     }
 
+    pub async fn is_current_room_different_channel(
+        &self,
+        user_id: UserId,
+        channel_id: ChannelId,
+    ) -> Result<bool> {
+        self.transaction(|tx| async move {
+            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+            enum QueryAs {
+                ChannelId,
+            }
+
+            let channel_id_model: Option<ChannelId> = room_participant::Entity::find()
+                .select_only()
+                .column_as(room::Column::ChannelId, QueryAs::ChannelId)
+                .inner_join(room::Entity)
+                .filter(room_participant::Column::UserId.eq(user_id))
+                .into_values::<_, QueryAs>()
+                .one(&*tx)
+                .await?;
+
+            let result = channel_id_model
+                .map(|channel_id_model| channel_id_model != channel_id)
+                .unwrap_or(false);
+
+            Ok(result)
+        })
+        .await
+    }
+
     pub async fn join_room(
         &self,
         room_id: RoomId,

crates/collab/src/rpc.rs 🔗

@@ -2276,6 +2276,14 @@ async fn join_channel(
 
     let joined_room = {
         let db = session.db().await;
+
+        if db
+            .is_current_room_different_channel(session.user_id, channel_id)
+            .await?
+        {
+            leave_room_for_session_with_guard(&session, &db).await?;
+        }
+
         let room_id = db.room_id_for_channel(channel_id).await?;
 
         let joined_room = db
@@ -2531,6 +2539,14 @@ fn channel_updated(
 
 async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
     let db = session.db().await;
+    update_user_contacts_with_guard(user_id, session, &db).await
+}
+
+async fn update_user_contacts_with_guard(
+    user_id: UserId,
+    session: &Session,
+    db: &DbHandle,
+) -> Result<()> {
     let contacts = db.get_contacts(user_id).await?;
     let busy = db.is_user_busy(user_id).await?;
 
@@ -2564,6 +2580,11 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()>
 }
 
 async fn leave_room_for_session(session: &Session) -> Result<()> {
+    let db = session.db().await;
+    leave_room_for_session_with_guard(session, &db).await
+}
+
+async fn leave_room_for_session_with_guard(session: &Session, db: &DbHandle) -> Result<()> {
     let mut contacts_to_update = HashSet::default();
 
     let room_id;
@@ -2574,7 +2595,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> {
     let channel_members;
     let channel_id;
 
-    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
+    if let Some(mut left_room) = db.leave_room(session.connection_id).await? {
         contacts_to_update.insert(session.user_id);
 
         for project in left_room.left_projects.values() {
@@ -2624,7 +2645,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> {
     }
 
     for contact_user_id in contacts_to_update {
-        update_user_contacts(contact_user_id, &session).await?;
+        update_user_contacts_with_guard(contact_user_id, &session, db).await?;
     }
 
     if let Some(live_kit) = session.live_kit_client.as_ref() {

crates/collab/src/tests/channel_tests.rs 🔗

@@ -304,3 +304,50 @@ async fn test_channel_room(
         }
     );
 }
+
+#[gpui::test]
+async fn test_channel_jumping(deterministic: Arc<Deterministic>, cx_a: &mut TestAppContext) {
+    deterministic.forbid_parking();
+    let mut server = TestServer::start(&deterministic).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+
+    let zed_id = server.make_channel("zed", (&client_a, cx_a), &mut []).await;
+    let rust_id = server
+        .make_channel("rust", (&client_a, cx_a), &mut [])
+        .await;
+
+    let active_call_a = cx_a.read(ActiveCall::global);
+
+    active_call_a
+        .update(cx_a, |active_call, cx| active_call.join_channel(zed_id, cx))
+        .await
+        .unwrap();
+
+    // Give everything a chance to observe user A joining
+    deterministic.run_until_parked();
+
+    client_a.channel_store().read_with(cx_a, |channels, _| {
+        assert_participants_eq(
+            channels.channel_participants(zed_id),
+            &[client_a.user_id().unwrap()],
+        );
+        assert_participants_eq(channels.channel_participants(rust_id), &[]);
+    });
+
+    active_call_a
+        .update(cx_a, |active_call, cx| {
+            active_call.join_channel(rust_id, cx)
+        })
+        .await
+        .unwrap();
+
+    deterministic.run_until_parked();
+
+    client_a.channel_store().read_with(cx_a, |channels, _| {
+        assert_participants_eq(channels.channel_participants(zed_id), &[]);
+        assert_participants_eq(
+            channels.channel_participants(rust_id),
+            &[client_a.user_id().unwrap()],
+        );
+    });
+}