Implement calling contacts into your current channel

Max Brunsfeld and Mikayla created

Co-authored-by: Mikayla <mikayla@zed.dev>

Change summary

crates/call/src/call.rs                  |  8 ++
crates/call/src/room.rs                  | 81 ++++++++++++-------------
crates/collab/src/db.rs                  | 36 +++++++++-
crates/collab/src/db/tests.rs            | 25 +------
crates/collab/src/rpc.rs                 | 40 ++++++++----
crates/collab/src/tests/channel_tests.rs | 74 +++++++++++++++++++++++
crates/collab_ui/src/collab_panel.rs     |  7 -
crates/rpc/proto/zed.proto               |  3 
8 files changed, 187 insertions(+), 87 deletions(-)

Detailed changes

crates/call/src/call.rs 🔗

@@ -6,7 +6,9 @@ use std::sync::Arc;
 
 use anyhow::{anyhow, Result};
 use call_settings::CallSettings;
-use client::{proto, ClickhouseEvent, Client, TelemetrySettings, TypedEnvelope, User, UserStore};
+use client::{
+    proto, ChannelId, ClickhouseEvent, Client, TelemetrySettings, TypedEnvelope, User, UserStore,
+};
 use collections::HashSet;
 use futures::{future::Shared, FutureExt};
 use postage::watch;
@@ -75,6 +77,10 @@ impl ActiveCall {
         }
     }
 
+    pub fn channel_id(&self, cx: &AppContext) -> Option<ChannelId> {
+        self.room()?.read(cx).channel_id()
+    }
+
     async fn handle_incoming_call(
         this: ModelHandle<Self>,
         envelope: TypedEnvelope<proto::IncomingCall>,

crates/call/src/room.rs 🔗

@@ -274,26 +274,13 @@ impl Room {
         user_store: ModelHandle<UserStore>,
         cx: &mut AppContext,
     ) -> Task<Result<ModelHandle<Self>>> {
-        cx.spawn(|mut cx| async move {
-            let response = client.request(proto::JoinChannel { channel_id }).await?;
-            let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
-            let room = cx.add_model(|cx| {
-                Self::new(
-                    room_proto.id,
-                    Some(channel_id),
-                    response.live_kit_connection_info,
-                    client,
-                    user_store,
-                    cx,
-                )
-            });
-
-            room.update(&mut cx, |room, cx| {
-                room.apply_room_update(room_proto, cx)?;
-                anyhow::Ok(())
-            })?;
-
-            Ok(room)
+        cx.spawn(|cx| async move {
+            Self::from_join_response(
+                client.request(proto::JoinChannel { channel_id }).await?,
+                client,
+                user_store,
+                cx,
+            )
         })
     }
 
@@ -303,30 +290,42 @@ impl Room {
         user_store: ModelHandle<UserStore>,
         cx: &mut AppContext,
     ) -> Task<Result<ModelHandle<Self>>> {
-        let room_id = call.room_id;
-        cx.spawn(|mut cx| async move {
-            let response = client.request(proto::JoinRoom { id: room_id }).await?;
-            let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
-            let room = cx.add_model(|cx| {
-                Self::new(
-                    room_id,
-                    None,
-                    response.live_kit_connection_info,
-                    client,
-                    user_store,
-                    cx,
-                )
-            });
-            room.update(&mut cx, |room, cx| {
-                room.leave_when_empty = true;
-                room.apply_room_update(room_proto, cx)?;
-                anyhow::Ok(())
-            })?;
-
-            Ok(room)
+        let id = call.room_id;
+        cx.spawn(|cx| async move {
+            Self::from_join_response(
+                client.request(proto::JoinRoom { id }).await?,
+                client,
+                user_store,
+                cx,
+            )
         })
     }
 
+    fn from_join_response(
+        response: proto::JoinRoomResponse,
+        client: Arc<Client>,
+        user_store: ModelHandle<UserStore>,
+        mut cx: AsyncAppContext,
+    ) -> Result<ModelHandle<Self>> {
+        let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
+        let room = cx.add_model(|cx| {
+            Self::new(
+                room_proto.id,
+                response.channel_id,
+                response.live_kit_connection_info,
+                client,
+                user_store,
+                cx,
+            )
+        });
+        room.update(&mut cx, |room, cx| {
+            room.leave_when_empty = room.channel_id.is_none();
+            room.apply_room_update(room_proto, cx)?;
+            anyhow::Ok(())
+        })?;
+        Ok(room)
+    }
+
     fn should_leave(&self) -> bool {
         self.leave_when_empty
             && self.pending_room_update.is_none()

crates/collab/src/db.rs 🔗

@@ -1376,15 +1376,27 @@ impl Database {
         &self,
         room_id: RoomId,
         user_id: UserId,
-        channel_id: Option<ChannelId>,
         connection: ConnectionId,
     ) -> Result<RoomGuard<JoinRoom>> {
         self.room_transaction(room_id, |tx| async move {
+            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+            enum QueryChannelId {
+                ChannelId,
+            }
+            let channel_id: Option<ChannelId> = room::Entity::find()
+                .select_only()
+                .column(room::Column::ChannelId)
+                .filter(room::Column::Id.eq(room_id))
+                .into_values::<_, QueryChannelId>()
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("no such room"))?;
+
             if let Some(channel_id) = channel_id {
                 self.check_user_is_channel_member(channel_id, user_id, &*tx)
                     .await?;
 
-                room_participant::ActiveModel {
+                room_participant::Entity::insert_many([room_participant::ActiveModel {
                     room_id: ActiveValue::set(room_id),
                     user_id: ActiveValue::set(user_id),
                     answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
@@ -1392,15 +1404,23 @@ impl Database {
                         connection.owner_id as i32,
                     ))),
                     answering_connection_lost: ActiveValue::set(false),
-                    // Redundant for the channel join use case, used for channel and call invitations
                     calling_user_id: ActiveValue::set(user_id),
                     calling_connection_id: ActiveValue::set(connection.id as i32),
                     calling_connection_server_id: ActiveValue::set(Some(ServerId(
                         connection.owner_id as i32,
                     ))),
                     ..Default::default()
-                }
-                .insert(&*tx)
+                }])
+                .on_conflict(
+                    OnConflict::columns([room_participant::Column::UserId])
+                        .update_columns([
+                            room_participant::Column::AnsweringConnectionId,
+                            room_participant::Column::AnsweringConnectionServerId,
+                            room_participant::Column::AnsweringConnectionLost,
+                        ])
+                        .to_owned(),
+                )
+                .exec(&*tx)
                 .await?;
             } else {
                 let result = room_participant::Entity::update_many()
@@ -4053,6 +4073,12 @@ impl<T> DerefMut for RoomGuard<T> {
     }
 }
 
+impl<T> RoomGuard<T> {
+    pub fn into_inner(self) -> T {
+        self.data
+    }
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 pub struct NewUserParams {
     pub github_login: String,

crates/collab/src/db/tests.rs 🔗

@@ -494,14 +494,9 @@ test_both_dbs!(
         )
         .await
         .unwrap();
-        db.join_room(
-            room_id,
-            user2.user_id,
-            None,
-            ConnectionId { owner_id, id: 1 },
-        )
-        .await
-        .unwrap();
+        db.join_room(room_id, user2.user_id, ConnectionId { owner_id, id: 1 })
+            .await
+            .unwrap();
         assert_eq!(db.project_count_excluding_admins().await.unwrap(), 0);
 
         db.share_project(room_id, ConnectionId { owner_id, id: 1 }, &[])
@@ -1113,12 +1108,7 @@ test_both_dbs!(
 
         // can join a room with membership to its channel
         let joined_room = db
-            .join_room(
-                room_1,
-                user_1,
-                Some(channel_1),
-                ConnectionId { owner_id, id: 1 },
-            )
+            .join_room(room_1, user_1, ConnectionId { owner_id, id: 1 })
             .await
             .unwrap();
         assert_eq!(joined_room.room.participants.len(), 1);
@@ -1126,12 +1116,7 @@ test_both_dbs!(
         drop(joined_room);
         // cannot join a room without membership to its channel
         assert!(db
-            .join_room(
-                room_1,
-                user_2,
-                Some(channel_1),
-                ConnectionId { owner_id, id: 1 }
-            )
+            .join_room(room_1, user_2, ConnectionId { owner_id, id: 1 })
             .await
             .is_err());
     }

crates/collab/src/rpc.rs 🔗

@@ -930,16 +930,26 @@ async fn join_room(
     session: Session,
 ) -> Result<()> {
     let room_id = RoomId::from_proto(request.id);
-    let room = {
+    let joined_room = {
         let room = session
             .db()
             .await
-            .join_room(room_id, session.user_id, None, session.connection_id)
+            .join_room(room_id, session.user_id, session.connection_id)
             .await?;
         room_updated(&room.room, &session.peer);
-        room.room.clone()
+        room.into_inner()
     };
 
+    if let Some(channel_id) = joined_room.channel_id {
+        channel_updated(
+            channel_id,
+            &joined_room.room,
+            &joined_room.channel_members,
+            &session.peer,
+            &*session.connection_pool().await,
+        )
+    }
+
     for connection_id in session
         .connection_pool()
         .await
@@ -958,7 +968,10 @@ async fn join_room(
 
     let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
         if let Some(token) = live_kit
-            .room_token(&room.live_kit_room, &session.user_id.to_string())
+            .room_token(
+                &joined_room.room.live_kit_room,
+                &session.user_id.to_string(),
+            )
             .trace_err()
         {
             Some(proto::LiveKitConnectionInfo {
@@ -973,7 +986,8 @@ async fn join_room(
     };
 
     response.send(proto::JoinRoomResponse {
-        room: Some(room),
+        room: Some(joined_room.room),
+        channel_id: joined_room.channel_id.map(|id| id.to_proto()),
         live_kit_connection_info,
     })?;
 
@@ -1151,9 +1165,11 @@ async fn rejoin_room(
             }
         }
 
-        room = mem::take(&mut rejoined_room.room);
+        let rejoined_room = rejoined_room.into_inner();
+
+        room = rejoined_room.room;
         channel_id = rejoined_room.channel_id;
-        channel_members = mem::take(&mut rejoined_room.channel_members);
+        channel_members = rejoined_room.channel_members;
     }
 
     if let Some(channel_id) = channel_id {
@@ -2421,12 +2437,7 @@ async fn join_channel(
         let room_id = db.room_id_for_channel(channel_id).await?;
 
         let joined_room = db
-            .join_room(
-                room_id,
-                session.user_id,
-                Some(channel_id),
-                session.connection_id,
-            )
+            .join_room(room_id, session.user_id, session.connection_id)
             .await?;
 
         let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
@@ -2445,12 +2456,13 @@ async fn join_channel(
 
         response.send(proto::JoinRoomResponse {
             room: Some(joined_room.room.clone()),
+            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
             live_kit_connection_info,
         })?;
 
         room_updated(&joined_room.room, &session.peer);
 
-        joined_room.clone()
+        joined_room.into_inner()
     };
 
     channel_updated(

crates/collab/src/tests/channel_tests.rs 🔗

@@ -696,6 +696,80 @@ async fn test_channel_rename(
     );
 }
 
+#[gpui::test]
+async fn test_call_from_channel(
+    deterministic: Arc<Deterministic>,
+    cx_a: &mut TestAppContext,
+    cx_b: &mut TestAppContext,
+    cx_c: &mut TestAppContext,
+) {
+    deterministic.forbid_parking();
+    let mut server = TestServer::start(&deterministic).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_b").await;
+    let client_c = server.create_client(cx_c, "user_c").await;
+    server
+        .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b)])
+        .await;
+
+    let channel_id = server
+        .make_channel(
+            "x",
+            (&client_a, cx_a),
+            &mut [(&client_b, cx_b), (&client_c, cx_c)],
+        )
+        .await;
+
+    let active_call_a = cx_a.read(ActiveCall::global);
+    let active_call_b = cx_b.read(ActiveCall::global);
+
+    active_call_a
+        .update(cx_a, |call, cx| call.join_channel(channel_id, cx))
+        .await
+        .unwrap();
+
+    // Client A calls client B while in the channel.
+    active_call_a
+        .update(cx_a, |call, cx| {
+            call.invite(client_b.user_id().unwrap(), None, cx)
+        })
+        .await
+        .unwrap();
+
+    // Client B accepts the call.
+    deterministic.run_until_parked();
+    active_call_b
+        .update(cx_b, |call, cx| call.accept_incoming(cx))
+        .await
+        .unwrap();
+
+    // Client B sees that they are now in the channel
+    deterministic.run_until_parked();
+    active_call_b.read_with(cx_b, |call, cx| {
+        assert_eq!(call.channel_id(cx), Some(channel_id));
+    });
+    client_b.channel_store().read_with(cx_b, |channels, _| {
+        assert_participants_eq(
+            channels.channel_participants(channel_id),
+            &[client_a.user_id().unwrap(), client_b.user_id().unwrap()],
+        );
+    });
+
+    // Clients A and C also see that client B is in the channel.
+    client_a.channel_store().read_with(cx_a, |channels, _| {
+        assert_participants_eq(
+            channels.channel_participants(channel_id),
+            &[client_a.user_id().unwrap(), client_b.user_id().unwrap()],
+        );
+    });
+    client_c.channel_store().read_with(cx_c, |channels, _| {
+        assert_participants_eq(
+            channels.channel_participants(channel_id),
+            &[client_a.user_id().unwrap(), client_b.user_id().unwrap()],
+        );
+    });
+}
+
 #[derive(Debug, PartialEq)]
 struct ExpectedChannel {
     depth: usize,

crates/collab_ui/src/collab_panel.rs 🔗

@@ -1183,11 +1183,8 @@ impl CollabPanel {
         let text = match section {
             Section::ActiveCall => {
                 let channel_name = iife!({
-                    let channel_id = ActiveCall::global(cx)
-                        .read(cx)
-                        .room()?
-                        .read(cx)
-                        .channel_id()?;
+                    let channel_id = ActiveCall::global(cx).read(cx).channel_id(cx)?;
+
                     let name = self
                         .channel_store
                         .read(cx)

crates/rpc/proto/zed.proto 🔗

@@ -176,7 +176,8 @@ message JoinRoom {
 
 message JoinRoomResponse {
     Room room = 1;
-    optional LiveKitConnectionInfo live_kit_connection_info = 2;
+    optional uint64 channel_id = 2;
+    optional LiveKitConnectionInfo live_kit_connection_info = 3;
 }
 
 message RejoinRoom {