Implement joining a room and sending updates after people join/leave

Antonio Scandurra created

Change summary

crates/client/src/user.rs              | 31 ++++++---
crates/collab/src/integration_tests.rs | 39 +++++++++++-
crates/collab/src/rpc.rs               | 84 ++++++++++++++++++----------
crates/collab/src/rpc/store.rs         | 60 +++++++++++++++++--
crates/room/src/room.rs                | 57 +++++++++++++-----
crates/rpc/proto/zed.proto             |  8 +-
crates/rpc/src/proto.rs                |  3 +
7 files changed, 211 insertions(+), 71 deletions(-)

Detailed changes

crates/client/src/user.rs 🔗

@@ -1,9 +1,8 @@
-use crate::call::Call;
-
 use super::{http::HttpClient, proto, Client, Status, TypedEnvelope};
+use crate::call::Call;
 use anyhow::{anyhow, Context, Result};
 use collections::{hash_map::Entry, BTreeSet, HashMap, HashSet};
-use futures::{channel::mpsc, future, AsyncReadExt, Future, Stream, StreamExt};
+use futures::{channel::mpsc, future, AsyncReadExt, Future, StreamExt};
 use gpui::{AsyncAppContext, Entity, ImageData, ModelContext, ModelHandle, Task};
 use postage::{sink::Sink, watch};
 use rpc::proto::{RequestMessage, UsersResponse};
@@ -68,7 +67,7 @@ pub struct UserStore {
     outgoing_contact_requests: Vec<Arc<User>>,
     pending_contact_requests: HashMap<u64, usize>,
     invite_info: Option<InviteInfo>,
-    incoming_calls: Vec<mpsc::UnboundedSender<Call>>,
+    incoming_call: (watch::Sender<Option<Call>>, watch::Receiver<Option<Call>>),
     client: Weak<Client>,
     http: Arc<dyn HttpClient>,
     _maintain_contacts: Task<()>,
@@ -119,6 +118,7 @@ impl UserStore {
             client.add_message_handler(cx.handle(), Self::handle_update_invite_info),
             client.add_message_handler(cx.handle(), Self::handle_show_contacts),
             client.add_request_handler(cx.handle(), Self::handle_incoming_call),
+            client.add_message_handler(cx.handle(), Self::handle_cancel_call),
         ];
         Self {
             users: Default::default(),
@@ -127,7 +127,7 @@ impl UserStore {
             incoming_contact_requests: Default::default(),
             outgoing_contact_requests: Default::default(),
             invite_info: None,
-            incoming_calls: Default::default(),
+            incoming_call: watch::channel(),
             client: Arc::downgrade(&client),
             update_contacts_tx,
             http,
@@ -219,21 +219,30 @@ impl UserStore {
                 .await?,
         };
         this.update(&mut cx, |this, _| {
-            this.incoming_calls
-                .retain(|tx| tx.unbounded_send(call.clone()).is_ok());
+            *this.incoming_call.0.borrow_mut() = Some(call);
         });
 
         Ok(proto::Ack {})
     }
 
+    async fn handle_cancel_call(
+        this: ModelHandle<Self>,
+        _: TypedEnvelope<proto::CancelCall>,
+        _: Arc<Client>,
+        mut cx: AsyncAppContext,
+    ) -> Result<()> {
+        this.update(&mut cx, |this, _| {
+            *this.incoming_call.0.borrow_mut() = None;
+        });
+        Ok(())
+    }
+
     pub fn invite_info(&self) -> Option<&InviteInfo> {
         self.invite_info.as_ref()
     }
 
-    pub fn incoming_calls(&mut self) -> impl 'static + Stream<Item = Call> {
-        let (tx, rx) = mpsc::unbounded();
-        self.incoming_calls.push(tx);
-        rx
+    pub fn incoming_call(&self) -> watch::Receiver<Option<Call>> {
+        self.incoming_call.1.clone()
     }
 
     async fn handle_update_contacts(

crates/collab/src/integration_tests.rs 🔗

@@ -98,18 +98,49 @@ async fn test_share_project_in_room(
     let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await;
     // room.publish_project(project_a.clone()).await.unwrap();
 
-    let mut incoming_calls_b = client_b
+    let mut incoming_call_b = client_b
         .user_store
-        .update(cx_b, |user, _| user.incoming_calls());
+        .update(cx_b, |user, _| user.incoming_call());
     room_a
         .update(cx_a, |room, cx| room.call(client_b.user_id().unwrap(), cx))
         .await
         .unwrap();
-    let call_b = incoming_calls_b.next().await.unwrap();
+    let call_b = incoming_call_b.next().await.unwrap().unwrap();
     let room_b = cx_b
-        .update(|cx| Room::join(call_b.room_id, client_b.clone(), cx))
+        .update(|cx| Room::join(&call_b, client_b.clone(), cx))
         .await
         .unwrap();
+    assert!(incoming_call_b.next().await.unwrap().is_none());
+    assert_eq!(
+        remote_participants(&room_a, &client_a, cx_a).await,
+        vec!["user_b"]
+    );
+    assert_eq!(
+        remote_participants(&room_b, &client_b, cx_b).await,
+        vec!["user_a"]
+    );
+
+    async fn remote_participants(
+        room: &ModelHandle<Room>,
+        client: &TestClient,
+        cx: &mut TestAppContext,
+    ) -> Vec<String> {
+        let users = room.update(cx, |room, cx| {
+            room.remote_participants()
+                .values()
+                .map(|participant| {
+                    client
+                        .user_store
+                        .update(cx, |users, cx| users.get_user(participant.user_id, cx))
+                })
+                .collect::<Vec<_>>()
+        });
+        let users = futures::future::try_join_all(users).await.unwrap();
+        users
+            .into_iter()
+            .map(|user| user.github_login.clone())
+            .collect()
+    }
 }
 
 #[gpui::test(iterations = 10)]

crates/collab/src/rpc.rs 🔗

@@ -152,6 +152,7 @@ impl Server {
         server
             .add_request_handler(Server::ping)
             .add_request_handler(Server::create_room)
+            .add_request_handler(Server::join_room)
             .add_request_handler(Server::call)
             .add_request_handler(Server::register_project)
             .add_request_handler(Server::unregister_project)
@@ -605,6 +606,26 @@ impl Server {
         Ok(())
     }
 
+    async fn join_room(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::JoinRoom>,
+        response: Response<proto::JoinRoom>,
+    ) -> Result<()> {
+        let room_id = request.payload.id;
+        let mut store = self.store().await;
+        let (room, recipient_ids) = store.join_room(room_id, request.sender_id)?;
+        for receiver_id in recipient_ids {
+            self.peer
+                .send(receiver_id, proto::CancelCall {})
+                .trace_err();
+        }
+        response.send(proto::JoinRoomResponse {
+            room: Some(room.clone()),
+        })?;
+        self.room_updated(room);
+        Ok(())
+    }
+
     async fn call(
         self: Arc<Server>,
         request: TypedEnvelope<proto::Call>,
@@ -612,34 +633,29 @@ impl Server {
     ) -> Result<()> {
         let to_user_id = UserId::from_proto(request.payload.to_user_id);
         let room_id = request.payload.room_id;
-        let (from_user_id, receiver_ids, room) =
-            self.store()
-                .await
-                .call(room_id, request.sender_id, to_user_id)?;
-        for participant in &room.participants {
-            self.peer
-                .send(
-                    ConnectionId(participant.peer_id),
-                    proto::RoomUpdated {
-                        room: Some(room.clone()),
-                    },
-                )
-                .trace_err();
-        }
-
-        let mut calls = receiver_ids
-            .into_iter()
-            .map(|receiver_id| {
-                self.peer.request(
-                    receiver_id,
-                    proto::IncomingCall {
-                        room_id,
-                        from_user_id: from_user_id.to_proto(),
-                        participant_user_ids: room.participants.iter().map(|p| p.user_id).collect(),
-                    },
-                )
-            })
-            .collect::<FuturesUnordered<_>>();
+        let mut calls = {
+            let mut store = self.store().await;
+            let (from_user_id, recipient_ids, room) =
+                store.call(room_id, request.sender_id, to_user_id)?;
+            self.room_updated(room);
+            recipient_ids
+                .into_iter()
+                .map(|recipient_id| {
+                    self.peer.request(
+                        recipient_id,
+                        proto::IncomingCall {
+                            room_id,
+                            from_user_id: from_user_id.to_proto(),
+                            participant_user_ids: room
+                                .participants
+                                .iter()
+                                .map(|p| p.user_id)
+                                .collect(),
+                        },
+                    )
+                })
+                .collect::<FuturesUnordered<_>>()
+        };
 
         while let Some(call_response) = calls.next().await {
             match call_response.as_ref() {
@@ -653,7 +669,16 @@ impl Server {
             }
         }
 
-        let room = self.store().await.call_failed(room_id, to_user_id)?;
+        {
+            let mut store = self.store().await;
+            let room = store.call_failed(room_id, to_user_id)?;
+            self.room_updated(&room);
+        }
+
+        Err(anyhow!("failed to ring call recipient"))?
+    }
+
+    fn room_updated(&self, room: &proto::Room) {
         for participant in &room.participants {
             self.peer
                 .send(
@@ -664,7 +689,6 @@ impl Server {
                 )
                 .trace_err();
         }
-        Err(anyhow!("failed to ring call recipient"))?
     }
 
     async fn register_project(

crates/collab/src/rpc/store.rs 🔗

@@ -25,7 +25,7 @@ pub struct Store {
 struct ConnectionState {
     user_id: UserId,
     admin: bool,
-    rooms: BTreeSet<RoomId>,
+    room: Option<RoomId>,
     projects: BTreeSet<ProjectId>,
     requested_projects: HashSet<ProjectId>,
     channels: HashSet<ChannelId>,
@@ -140,7 +140,7 @@ impl Store {
             ConnectionState {
                 user_id,
                 admin,
-                rooms: Default::default(),
+                room: Default::default(),
                 projects: Default::default(),
                 requested_projects: Default::default(),
                 channels: Default::default(),
@@ -333,6 +333,11 @@ impl Store {
             .connections
             .get_mut(&creator_connection_id)
             .ok_or_else(|| anyhow!("no such connection"))?;
+        anyhow::ensure!(
+            connection.room.is_none(),
+            "cannot participate in more than one room at once"
+        );
+
         let mut room = proto::Room::default();
         room.participants.push(proto::Participant {
             user_id: connection.user_id.to_proto(),
@@ -347,16 +352,57 @@ impl Store {
 
         let room_id = post_inc(&mut self.next_room_id);
         self.rooms.insert(room_id, room);
-        connection.rooms.insert(room_id);
+        connection.room = Some(room_id);
         Ok(room_id)
     }
 
+    pub fn join_room(
+        &mut self,
+        room_id: u64,
+        connection_id: ConnectionId,
+    ) -> Result<(&proto::Room, Vec<ConnectionId>)> {
+        let connection = self
+            .connections
+            .get_mut(&connection_id)
+            .ok_or_else(|| anyhow!("no such connection"))?;
+        anyhow::ensure!(
+            connection.room.is_none(),
+            "cannot participate in more than one room at once"
+        );
+
+        let user_id = connection.user_id;
+        let recipient_ids = self.connection_ids_for_user(user_id).collect::<Vec<_>>();
+
+        let room = self
+            .rooms
+            .get_mut(&room_id)
+            .ok_or_else(|| anyhow!("no such room"))?;
+        anyhow::ensure!(
+            room.pending_calls_to_user_ids.contains(&user_id.to_proto()),
+            anyhow!("no such room")
+        );
+        room.pending_calls_to_user_ids
+            .retain(|pending| *pending != user_id.to_proto());
+        room.participants.push(proto::Participant {
+            user_id: user_id.to_proto(),
+            peer_id: connection_id.0,
+            project_ids: Default::default(),
+            location: Some(proto::ParticipantLocation {
+                variant: Some(proto::participant_location::Variant::External(
+                    proto::participant_location::External {},
+                )),
+            }),
+        });
+
+        Ok((room, recipient_ids))
+    }
+
     pub fn call(
         &mut self,
         room_id: RoomId,
         from_connection_id: ConnectionId,
         to_user_id: UserId,
-    ) -> Result<(UserId, Vec<ConnectionId>, proto::Room)> {
+    ) -> Result<(UserId, Vec<ConnectionId>, &proto::Room)> {
         let from_user_id = self.user_id_for_connection(from_connection_id)?;
         let to_connection_ids = self.connection_ids_for_user(to_user_id).collect::<Vec<_>>();
         let room = self
@@ -377,17 +423,17 @@ impl Store {
         );
         room.pending_calls_to_user_ids.push(to_user_id.to_proto());
 
-        Ok((from_user_id, to_connection_ids, room.clone()))
+        Ok((from_user_id, to_connection_ids, room))
     }
 
-    pub fn call_failed(&mut self, room_id: RoomId, to_user_id: UserId) -> Result<proto::Room> {
+    pub fn call_failed(&mut self, room_id: RoomId, to_user_id: UserId) -> Result<&proto::Room> {
         let room = self
             .rooms
             .get_mut(&room_id)
             .ok_or_else(|| anyhow!("no such room"))?;
         room.pending_calls_to_user_ids
             .retain(|user_id| UserId::from_proto(*user_id) != to_user_id);
-        Ok(room.clone())
+        Ok(room)
     }
 
     pub fn register_project(

crates/room/src/room.rs 🔗

@@ -1,9 +1,9 @@
 mod participant;
 
 use anyhow::{anyhow, Result};
-use client::{proto, Client, PeerId};
+use client::{call::Call, proto, Client, PeerId, TypedEnvelope};
 use collections::HashMap;
-use gpui::{Entity, ModelContext, ModelHandle, MutableAppContext, Task};
+use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
 use participant::{LocalParticipant, ParticipantLocation, RemoteParticipant};
 use project::Project;
 use std::sync::Arc;
@@ -22,6 +22,7 @@ pub struct Room {
     local_participant: LocalParticipant,
     remote_participants: HashMap<PeerId, RemoteParticipant>,
     client: Arc<Client>,
+    _subscriptions: Vec<client::Subscription>,
 }
 
 impl Entity for Room {
@@ -40,40 +41,64 @@ impl Room {
     }
 
     pub fn join(
-        id: u64,
+        call: &Call,
         client: Arc<Client>,
         cx: &mut MutableAppContext,
     ) -> Task<Result<ModelHandle<Self>>> {
+        let room_id = call.room_id;
         cx.spawn(|mut cx| async move {
-            let response = client.request(proto::JoinRoom { id }).await?;
+            let response = client.request(proto::JoinRoom { id: room_id }).await?;
             let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
-            let room = cx.add_model(|cx| Self::new(id, client, cx));
-            room.update(&mut cx, |room, cx| room.apply_update(room_proto, cx))?;
+            let room = cx.add_model(|cx| Self::new(room_id, client, cx));
+            room.update(&mut cx, |room, cx| room.apply_room_update(room_proto, cx))?;
             Ok(room)
         })
     }
 
-    fn new(id: u64, client: Arc<Client>, _: &mut ModelContext<Self>) -> Self {
+    fn new(id: u64, client: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
         Self {
             id,
             local_participant: LocalParticipant {
                 projects: Default::default(),
             },
             remote_participants: Default::default(),
+            _subscriptions: vec![client.add_message_handler(cx.handle(), Self::handle_room_updated)],
             client,
         }
     }
 
-    fn apply_update(&mut self, room: proto::Room, cx: &mut ModelContext<Self>) -> Result<()> {
+    pub fn remote_participants(&self) -> &HashMap<PeerId, RemoteParticipant> {
+        &self.remote_participants
+    }
+
+    async fn handle_room_updated(
+        this: ModelHandle<Self>,
+        envelope: TypedEnvelope<proto::RoomUpdated>,
+        _: Arc<Client>,
+        mut cx: AsyncAppContext,
+    ) -> Result<()> {
+        let room = envelope
+            .payload
+            .room
+            .ok_or_else(|| anyhow!("invalid room"))?;
+        this.update(&mut cx, |this, cx| this.apply_room_update(room, cx))?;
+        Ok(())
+    }
+
+    fn apply_room_update(&mut self, room: proto::Room, cx: &mut ModelContext<Self>) -> Result<()> {
+        // TODO: compute diff instead of clearing participants
+        self.remote_participants.clear();
         for participant in room.participants {
-            self.remote_participants.insert(
-                PeerId(participant.peer_id),
-                RemoteParticipant {
-                    user_id: participant.user_id,
-                    projects: Default::default(), // TODO: populate projects
-                    location: ParticipantLocation::from_proto(participant.location)?,
-                },
-            );
+            if Some(participant.user_id) != self.client.user_id() {
+                self.remote_participants.insert(
+                    PeerId(participant.peer_id),
+                    RemoteParticipant {
+                        user_id: participant.user_id,
+                        projects: Default::default(), // TODO: populate projects
+                        location: ParticipantLocation::from_proto(participant.location)?,
+                    },
+                );
+            }
         }
         cx.notify();
         Ok(())

crates/rpc/proto/zed.proto 🔗

@@ -17,7 +17,8 @@ message Envelope {
         JoinRoomResponse join_room_response = 11;
         Call call = 12;
         IncomingCall incoming_call = 1000;
-        RespondToCall respond_to_call = 13;
+        CancelCall cancel_call = 1001;
+        DeclineCall decline_call = 13;
         RoomUpdated room_updated = 14;
 
         RegisterProject register_project = 15;
@@ -184,9 +185,10 @@ message IncomingCall {
     repeated uint64 participant_user_ids = 3;
 }
 
-message RespondToCall {
+message CancelCall {}
+
+message DeclineCall {
     uint64 room_id = 1;
-    bool accept = 2;
 }
 
 message RoomUpdated {

crates/rpc/src/proto.rs 🔗

@@ -84,12 +84,14 @@ messages!(
     (BufferReloaded, Foreground),
     (BufferSaved, Foreground),
     (Call, Foreground),
+    (CancelCall, Foreground),
     (ChannelMessageSent, Foreground),
     (CopyProjectEntry, Foreground),
     (CreateBufferForPeer, Foreground),
     (CreateProjectEntry, Foreground),
     (CreateRoom, Foreground),
     (CreateRoomResponse, Foreground),
+    (DeclineCall, Foreground),
     (DeleteProjectEntry, Foreground),
     (Error, Foreground),
     (Follow, Foreground),
@@ -186,6 +188,7 @@ request_messages!(
     (CopyProjectEntry, ProjectEntryResponse),
     (CreateProjectEntry, ProjectEntryResponse),
     (CreateRoom, CreateRoomResponse),
+    (DeclineCall, Ack),
     (DeleteProjectEntry, ProjectEntryResponse),
     (Follow, FollowResponse),
     (FormatBuffers, FormatBuffersResponse),