Introduce call infrastructure

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/client/src/call.rs              |  3 
crates/client/src/channel.rs           |  2 
crates/client/src/client.rs            | 23 ++++++
crates/client/src/user.rs              | 97 ++++++++++++++++++++-------
crates/collab/src/integration_tests.rs | 12 +-
crates/collab/src/rpc.rs               | 63 ++++++++++++++++++
crates/collab/src/rpc/store.rs         | 39 +++++++++++
crates/project/src/project.rs          |  6 
crates/room/src/room.rs                | 24 +++++-
crates/rpc/proto/zed.proto             | 20 ++++
crates/rpc/src/proto.rs                |  5 +
11 files changed, 249 insertions(+), 45 deletions(-)

Detailed changes

crates/client/src/call.rs 🔗

@@ -3,6 +3,7 @@ use std::sync::Arc;
 
 #[derive(Clone)]
 pub struct Call {
-    pub from: Vec<Arc<User>>,
     pub room_id: u64,
+    pub from: Arc<User>,
+    pub participants: Vec<Arc<User>>,
 }

crates/client/src/channel.rs 🔗

@@ -530,7 +530,7 @@ impl ChannelMessage {
     ) -> Result<Self> {
         let sender = user_store
             .update(cx, |user_store, cx| {
-                user_store.fetch_user(message.sender_id, cx)
+                user_store.get_user(message.sender_id, cx)
             })
             .await?;
         Ok(ChannelMessage {

crates/client/src/client.rs 🔗

@@ -422,6 +422,29 @@ impl Client {
         }
     }
 
+    pub fn add_request_handler<M, E, H, F>(
+        self: &Arc<Self>,
+        model: ModelHandle<E>,
+        handler: H,
+    ) -> Subscription
+    where
+        M: RequestMessage,
+        E: Entity,
+        H: 'static
+            + Send
+            + Sync
+            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
+        F: 'static + Future<Output = Result<M::Response>>,
+    {
+        self.add_message_handler(model, move |handle, envelope, this, cx| {
+            Self::respond_to_request(
+                envelope.receipt(),
+                handler(handle, envelope, this.clone(), cx),
+                this,
+            )
+        })
+    }
+
     pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
     where
         M: EntityMessage,

crates/client/src/user.rs 🔗

@@ -5,7 +5,7 @@ use anyhow::{anyhow, Context, Result};
 use collections::{hash_map::Entry, BTreeSet, HashMap, HashSet};
 use futures::{channel::mpsc, future, AsyncReadExt, Future, Stream, StreamExt};
 use gpui::{AsyncAppContext, Entity, ImageData, ModelContext, ModelHandle, Task};
-use postage::{broadcast, sink::Sink, watch};
+use postage::{sink::Sink, watch};
 use rpc::proto::{RequestMessage, UsersResponse};
 use std::sync::{Arc, Weak};
 use util::TryFutureExt as _;
@@ -68,7 +68,7 @@ pub struct UserStore {
     outgoing_contact_requests: Vec<Arc<User>>,
     pending_contact_requests: HashMap<u64, usize>,
     invite_info: Option<InviteInfo>,
-    incoming_calls: broadcast::Sender<Call>,
+    incoming_calls: Vec<mpsc::UnboundedSender<Call>>,
     client: Weak<Client>,
     http: Arc<dyn HttpClient>,
     _maintain_contacts: Task<()>,
@@ -118,8 +118,8 @@ impl UserStore {
             client.add_message_handler(cx.handle(), Self::handle_update_contacts),
             client.add_message_handler(cx.handle(), Self::handle_update_invite_info),
             client.add_message_handler(cx.handle(), Self::handle_show_contacts),
+            client.add_request_handler(cx.handle(), Self::handle_incoming_call),
         ];
-        let (incoming_calls, _) = broadcast::channel(32);
         Self {
             users: Default::default(),
             current_user: current_user_rx,
@@ -127,7 +127,7 @@ impl UserStore {
             incoming_contact_requests: Default::default(),
             outgoing_contact_requests: Default::default(),
             invite_info: None,
-            incoming_calls,
+            incoming_calls: Default::default(),
             client: Arc::downgrade(&client),
             update_contacts_tx,
             http,
@@ -148,7 +148,7 @@ impl UserStore {
                         Status::Connected { .. } => {
                             if let Some((this, user_id)) = this.upgrade(&cx).zip(client.user_id()) {
                                 let user = this
-                                    .update(&mut cx, |this, cx| this.fetch_user(user_id, cx))
+                                    .update(&mut cx, |this, cx| this.get_user(user_id, cx))
                                     .log_err()
                                     .await;
                                 current_user_tx.send(user).await.ok();
@@ -199,12 +199,41 @@ impl UserStore {
         Ok(())
     }
 
+    async fn handle_incoming_call(
+        this: ModelHandle<Self>,
+        envelope: TypedEnvelope<proto::IncomingCall>,
+        _: Arc<Client>,
+        mut cx: AsyncAppContext,
+    ) -> Result<proto::Ack> {
+        let call = Call {
+            room_id: envelope.payload.room_id,
+            participants: this
+                .update(&mut cx, |this, cx| {
+                    this.get_users(envelope.payload.participant_user_ids, cx)
+                })
+                .await?,
+            from: this
+                .update(&mut cx, |this, cx| {
+                    this.get_user(envelope.payload.from_user_id, cx)
+                })
+                .await?,
+        };
+        this.update(&mut cx, |this, _| {
+            this.incoming_calls
+                .retain(|tx| tx.unbounded_send(call.clone()).is_ok());
+        });
+
+        Ok(proto::Ack {})
+    }
+
     pub fn invite_info(&self) -> Option<&InviteInfo> {
         self.invite_info.as_ref()
     }
 
-    pub fn incoming_calls(&self) -> impl 'static + Stream<Item = Call> {
-        self.incoming_calls.subscribe()
+    pub fn incoming_calls(&mut self) -> impl 'static + Stream<Item = Call> {
+        let (tx, rx) = mpsc::unbounded();
+        self.incoming_calls.push(tx);
+        rx
     }
 
     async fn handle_update_contacts(
@@ -266,9 +295,7 @@ impl UserStore {
                     for request in message.incoming_requests {
                         incoming_requests.push({
                             let user = this
-                                .update(&mut cx, |this, cx| {
-                                    this.fetch_user(request.requester_id, cx)
-                                })
+                                .update(&mut cx, |this, cx| this.get_user(request.requester_id, cx))
                                 .await?;
                             (user, request.should_notify)
                         });
@@ -277,7 +304,7 @@ impl UserStore {
                     let mut outgoing_requests = Vec::new();
                     for requested_user_id in message.outgoing_requests {
                         outgoing_requests.push(
-                            this.update(&mut cx, |this, cx| this.fetch_user(requested_user_id, cx))
+                            this.update(&mut cx, |this, cx| this.get_user(requested_user_id, cx))
                                 .await?,
                         );
                     }
@@ -518,19 +545,37 @@ impl UserStore {
 
     pub fn get_users(
         &mut self,
-        mut user_ids: Vec<u64>,
+        user_ids: Vec<u64>,
         cx: &mut ModelContext<Self>,
-    ) -> Task<Result<()>> {
-        user_ids.retain(|id| !self.users.contains_key(id));
-        if user_ids.is_empty() {
-            Task::ready(Ok(()))
-        } else {
-            let load = self.load_users(proto::GetUsers { user_ids }, cx);
-            cx.foreground().spawn(async move {
-                load.await?;
-                Ok(())
+    ) -> Task<Result<Vec<Arc<User>>>> {
+        let mut user_ids_to_fetch = user_ids.clone();
+        user_ids_to_fetch.retain(|id| !self.users.contains_key(id));
+
+        cx.spawn(|this, mut cx| async move {
+            if !user_ids_to_fetch.is_empty() {
+                this.update(&mut cx, |this, cx| {
+                    this.load_users(
+                        proto::GetUsers {
+                            user_ids: user_ids_to_fetch,
+                        },
+                        cx,
+                    )
+                })
+                .await?;
+            }
+
+            this.read_with(&cx, |this, _| {
+                user_ids
+                    .iter()
+                    .map(|user_id| {
+                        this.users
+                            .get(user_id)
+                            .cloned()
+                            .ok_or_else(|| anyhow!("user {} not found", user_id))
+                    })
+                    .collect()
             })
-        }
+        })
     }
 
     pub fn fuzzy_search_users(
@@ -541,7 +586,7 @@ impl UserStore {
         self.load_users(proto::FuzzySearchUsers { query }, cx)
     }
 
-    pub fn fetch_user(
+    pub fn get_user(
         &mut self,
         user_id: u64,
         cx: &mut ModelContext<Self>,
@@ -621,7 +666,7 @@ impl Contact {
     ) -> Result<Self> {
         let user = user_store
             .update(cx, |user_store, cx| {
-                user_store.fetch_user(contact.user_id, cx)
+                user_store.get_user(contact.user_id, cx)
             })
             .await?;
         let mut projects = Vec::new();
@@ -630,9 +675,7 @@ impl Contact {
             for participant_id in project.guests {
                 guests.insert(
                     user_store
-                        .update(cx, |user_store, cx| {
-                            user_store.fetch_user(participant_id, cx)
-                        })
+                        .update(cx, |user_store, cx| user_store.get_user(participant_id, cx))
                         .await?,
                 );
             }

crates/collab/src/integration_tests.rs 🔗

@@ -100,16 +100,16 @@ async fn test_share_project_in_room(
 
     let mut incoming_calls_b = client_b
         .user_store
-        .read_with(cx_b, |user, _| user.incoming_calls());
-    let user_b_joined = room_a.update(cx_a, |room, cx| {
-        room.invite(client_b.user_id().unwrap(), cx)
-    });
+        .update(cx_b, |user, _| user.incoming_calls());
+    room_a
+        .update(cx_a, |room, cx| room.call(client_b.user_id().unwrap(), cx))
+        .await
+        .unwrap();
     let call_b = incoming_calls_b.next().await.unwrap();
     let room_b = cx_b
         .update(|cx| Room::join(call_b.room_id, client_b.clone(), cx))
         .await
         .unwrap();
-    user_b_joined.await.unwrap();
 }
 
 #[gpui::test(iterations = 10)]
@@ -512,7 +512,7 @@ async fn test_cancel_join_request(
     let user_b = client_a
         .user_store
         .update(cx_a, |store, cx| {
-            store.fetch_user(client_b.user_id().unwrap(), cx)
+            store.get_user(client_b.user_id().unwrap(), cx)
         })
         .await
         .unwrap();

crates/collab/src/rpc.rs 🔗

@@ -152,6 +152,7 @@ impl Server {
         server
             .add_request_handler(Server::ping)
             .add_request_handler(Server::create_room)
+            .add_request_handler(Server::call)
             .add_request_handler(Server::register_project)
             .add_request_handler(Server::unregister_project)
             .add_request_handler(Server::join_project)
@@ -604,6 +605,68 @@ impl Server {
         Ok(())
     }
 
+    async fn call(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::Call>,
+        response: Response<proto::Call>,
+    ) -> Result<()> {
+        let to_user_id = UserId::from_proto(request.payload.to_user_id);
+        let room_id = request.payload.room_id;
+        let (from_user_id, receiver_ids, room) =
+            self.store()
+                .await
+                .call(room_id, request.sender_id, to_user_id)?;
+        for participant in &room.participants {
+            self.peer
+                .send(
+                    ConnectionId(participant.peer_id),
+                    proto::RoomUpdated {
+                        room: Some(room.clone()),
+                    },
+                )
+                .trace_err();
+        }
+
+        let mut calls = receiver_ids
+            .into_iter()
+            .map(|receiver_id| {
+                self.peer.request(
+                    receiver_id,
+                    proto::IncomingCall {
+                        room_id,
+                        from_user_id: from_user_id.to_proto(),
+                        participant_user_ids: room.participants.iter().map(|p| p.user_id).collect(),
+                    },
+                )
+            })
+            .collect::<FuturesUnordered<_>>();
+
+        while let Some(call_response) = calls.next().await {
+            match call_response.as_ref() {
+                Ok(_) => {
+                    response.send(proto::Ack {})?;
+                    return Ok(());
+                }
+                Err(_) => {
+                    call_response.trace_err();
+                }
+            }
+        }
+
+        let room = self.store().await.call_failed(room_id, to_user_id)?;
+        for participant in &room.participants {
+            self.peer
+                .send(
+                    ConnectionId(participant.peer_id),
+                    proto::RoomUpdated {
+                        room: Some(room.clone()),
+                    },
+                )
+                .trace_err();
+        }
+        Err(anyhow!("failed to ring call recipient"))?
+    }
+
     async fn register_project(
         self: Arc<Server>,
         request: TypedEnvelope<proto::RegisterProject>,

crates/collab/src/rpc/store.rs 🔗

@@ -351,6 +351,45 @@ impl Store {
         Ok(room_id)
     }
 
+    pub fn call(
+        &mut self,
+        room_id: RoomId,
+        from_connection_id: ConnectionId,
+        to_user_id: UserId,
+    ) -> Result<(UserId, Vec<ConnectionId>, proto::Room)> {
+        let from_user_id = self.user_id_for_connection(from_connection_id)?;
+        let to_connection_ids = self.connection_ids_for_user(to_user_id).collect::<Vec<_>>();
+        let room = self
+            .rooms
+            .get_mut(&room_id)
+            .ok_or_else(|| anyhow!("no such room"))?;
+        anyhow::ensure!(
+            room.participants
+                .iter()
+                .any(|participant| participant.peer_id == from_connection_id.0),
+            "no such room"
+        );
+        anyhow::ensure!(
+            room.pending_calls_to_user_ids
+                .iter()
+                .all(|user_id| UserId::from_proto(*user_id) != to_user_id),
+            "cannot call the same user more than once"
+        );
+        room.pending_calls_to_user_ids.push(to_user_id.to_proto());
+
+        Ok((from_user_id, to_connection_ids, room.clone()))
+    }
+
+    pub fn call_failed(&mut self, room_id: RoomId, to_user_id: UserId) -> Result<proto::Room> {
+        let room = self
+            .rooms
+            .get_mut(&room_id)
+            .ok_or_else(|| anyhow!("no such room"))?;
+        room.pending_calls_to_user_ids
+            .retain(|user_id| UserId::from_proto(*user_id) != to_user_id);
+        Ok(room.clone())
+    }
+
     pub fn register_project(
         &mut self,
         host_connection_id: ConnectionId,

crates/project/src/project.rs 🔗

@@ -4744,7 +4744,7 @@ impl Project {
         } else {
             let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
             let user = user_store
-                .update(&mut cx, |store, cx| store.fetch_user(user_id, cx))
+                .update(&mut cx, |store, cx| store.get_user(user_id, cx))
                 .await?;
             this.update(&mut cx, |_, cx| cx.emit(Event::ContactRequestedJoin(user)));
         }
@@ -4828,7 +4828,7 @@ impl Project {
         let user = this
             .update(&mut cx, |this, cx| {
                 this.user_store.update(cx, |user_store, cx| {
-                    user_store.fetch_user(envelope.payload.requester_id, cx)
+                    user_store.get_user(envelope.payload.requester_id, cx)
                 })
             })
             .await?;
@@ -6258,7 +6258,7 @@ impl Collaborator {
         cx: &mut AsyncAppContext,
     ) -> impl Future<Output = Result<Self>> {
         let user = user_store.update(cx, |user_store, cx| {
-            user_store.fetch_user(message.user_id, cx)
+            user_store.get_user(message.user_id, cx)
         });
 
         async move {

crates/room/src/room.rs 🔗

@@ -12,6 +12,11 @@ pub enum Event {
     PeerChangedActiveProject,
 }
 
+pub enum CallResponse {
+    Accepted,
+    Rejected,
+}
+
 pub struct Room {
     id: u64,
     local_participant: LocalParticipant,
@@ -43,7 +48,7 @@ impl Room {
             let response = client.request(proto::JoinRoom { id }).await?;
             let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
             let room = cx.add_model(|cx| Self::new(id, client, cx));
-            room.update(&mut cx, |room, cx| room.refresh(room_proto, cx))?;
+            room.update(&mut cx, |room, cx| room.apply_update(room_proto, cx))?;
             Ok(room)
         })
     }
@@ -59,7 +64,7 @@ impl Room {
         }
     }
 
-    fn refresh(&mut self, room: proto::Room, cx: &mut ModelContext<Self>) -> Result<()> {
+    fn apply_update(&mut self, room: proto::Room, cx: &mut ModelContext<Self>) -> Result<()> {
         for participant in room.participants {
             self.remote_participants.insert(
                 PeerId(participant.peer_id),
@@ -70,11 +75,22 @@ impl Room {
                 },
             );
         }
+        cx.notify();
         Ok(())
     }
 
-    pub fn invite(&mut self, user_id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
-        todo!()
+    pub fn call(&mut self, to_user_id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+        let client = self.client.clone();
+        let room_id = self.id;
+        cx.foreground().spawn(async move {
+            client
+                .request(proto::Call {
+                    room_id,
+                    to_user_id,
+                })
+                .await?;
+            Ok(())
+        })
     }
 
     pub async fn publish_project(&mut self, project: ModelHandle<Project>) -> Result<()> {

crates/rpc/proto/zed.proto 🔗

@@ -16,7 +16,8 @@ message Envelope {
         JoinRoom join_room = 10;
         JoinRoomResponse join_room_response = 11;
         Call call = 12;
-        CallResponse call_response = 13;
+        IncomingCall incoming_call = 1000;
+        RespondToCall respond_to_call = 13;
         RoomUpdated room_updated = 14;
 
         RegisterProject register_project = 15;
@@ -149,6 +150,7 @@ message JoinRoomResponse {
 
 message Room {
     repeated Participant participants = 1;
+    repeated uint64 pending_calls_to_user_ids = 2;
 }
 
 message Participant {
@@ -171,9 +173,21 @@ message ParticipantLocation {
     message External {}
 }
 
-message Call {}
+message Call {
+    uint64 room_id = 1;
+    uint64 to_user_id = 2;
+}
+
+message IncomingCall {
+    uint64 room_id = 1;
+    uint64 from_user_id = 2;
+    repeated uint64 participant_user_ids = 3;
+}
 
-message CallResponse {}
+message RespondToCall {
+    uint64 room_id = 1;
+    bool accept = 2;
+}
 
 message RoomUpdated {
     Room room = 1;

crates/rpc/src/proto.rs 🔗

@@ -83,6 +83,7 @@ messages!(
     (ApplyCompletionAdditionalEditsResponse, Background),
     (BufferReloaded, Foreground),
     (BufferSaved, Foreground),
+    (Call, Foreground),
     (ChannelMessageSent, Foreground),
     (CopyProjectEntry, Foreground),
     (CreateBufferForPeer, Foreground),
@@ -117,6 +118,7 @@ messages!(
     (GetProjectSymbols, Background),
     (GetProjectSymbolsResponse, Background),
     (GetUsers, Foreground),
+    (IncomingCall, Foreground),
     (UsersResponse, Foreground),
     (JoinChannel, Foreground),
     (JoinChannelResponse, Foreground),
@@ -151,6 +153,7 @@ messages!(
     (RequestJoinProject, Foreground),
     (RespondToContactRequest, Foreground),
     (RespondToJoinProjectRequest, Foreground),
+    (RoomUpdated, Foreground),
     (SaveBuffer, Foreground),
     (SearchProject, Background),
     (SearchProjectResponse, Background),
@@ -179,6 +182,7 @@ request_messages!(
         ApplyCompletionAdditionalEdits,
         ApplyCompletionAdditionalEditsResponse
     ),
+    (Call, Ack),
     (CopyProjectEntry, ProjectEntryResponse),
     (CreateProjectEntry, ProjectEntryResponse),
     (CreateRoom, CreateRoomResponse),
@@ -200,6 +204,7 @@ request_messages!(
     (JoinChannel, JoinChannelResponse),
     (JoinProject, JoinProjectResponse),
     (JoinRoom, JoinRoomResponse),
+    (IncomingCall, Ack),
     (OpenBufferById, OpenBufferResponse),
     (OpenBufferByPath, OpenBufferResponse),
     (OpenBufferForSymbol, OpenBufferForSymbolResponse),