Implement call cancellation

Antonio Scandurra created

Change summary

crates/call/src/call.rs                | 20 +++++-
crates/collab/src/integration_tests.rs | 82 ++++++++++++++++++++++++++++
crates/collab/src/rpc.rs               | 25 +++++++
crates/collab/src/rpc/store.rs         | 45 +++++++++++++++
crates/rpc/proto/zed.proto             |  9 ++
crates/rpc/src/proto.rs                |  2 
6 files changed, 176 insertions(+), 7 deletions(-)

Detailed changes

crates/call/src/call.rs 🔗

@@ -52,7 +52,7 @@ impl ActiveCall {
             incoming_call: watch::channel(),
             _subscriptions: vec![
                 client.add_request_handler(cx.handle(), Self::handle_incoming_call),
-                client.add_message_handler(cx.handle(), Self::handle_cancel_call),
+                client.add_message_handler(cx.handle(), Self::handle_call_canceled),
             ],
             client,
             user_store,
@@ -87,9 +87,9 @@ impl ActiveCall {
         Ok(proto::Ack {})
     }
 
-    async fn handle_cancel_call(
+    async fn handle_call_canceled(
         this: ModelHandle<Self>,
-        _: TypedEnvelope<proto::CancelCall>,
+        _: TypedEnvelope<proto::CallCanceled>,
         _: Arc<Client>,
         mut cx: AsyncAppContext,
     ) -> Result<()> {
@@ -140,6 +140,20 @@ impl ActiveCall {
         })
     }
 
+    pub fn cancel_invite(
+        &mut self,
+        recipient_user_id: u64,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
+        let client = self.client.clone();
+        cx.foreground().spawn(async move {
+            client
+                .request(proto::CancelCall { recipient_user_id })
+                .await?;
+            anyhow::Ok(())
+        })
+    }
+
     pub fn incoming(&self) -> watch::Receiver<Option<IncomingCall>> {
         self.incoming_call.1.clone()
     }

crates/collab/src/integration_tests.rs 🔗

@@ -401,6 +401,88 @@ async fn test_leaving_room_on_disconnection(
     );
 }
 
+#[gpui::test(iterations = 10)]
+async fn test_calls_on_multiple_connections(
+    deterministic: Arc<Deterministic>,
+    cx_a: &mut TestAppContext,
+    cx_b1: &mut TestAppContext,
+    cx_b2: &mut TestAppContext,
+) {
+    deterministic.forbid_parking();
+    let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b1 = server.create_client(cx_b1, "user_b").await;
+    let _client_b2 = server.create_client(cx_b2, "user_b").await;
+    server
+        .make_contacts(&mut [(&client_a, cx_a), (&client_b1, cx_b1)])
+        .await;
+
+    let active_call_a = cx_a.read(ActiveCall::global);
+    let active_call_b1 = cx_b1.read(ActiveCall::global);
+    let active_call_b2 = cx_b2.read(ActiveCall::global);
+    let mut incoming_call_b1 = active_call_b1.read_with(cx_b1, |call, _| call.incoming());
+    let mut incoming_call_b2 = active_call_b2.read_with(cx_b2, |call, _| call.incoming());
+    assert!(incoming_call_b1.next().await.unwrap().is_none());
+    assert!(incoming_call_b2.next().await.unwrap().is_none());
+
+    // Call user B from client A, ensuring both clients for user B ring.
+    active_call_a
+        .update(cx_a, |call, cx| {
+            call.invite(client_b1.user_id().unwrap(), None, cx)
+        })
+        .await
+        .unwrap();
+    assert!(incoming_call_b1.next().await.unwrap().is_some());
+    assert!(incoming_call_b2.next().await.unwrap().is_some());
+
+    // User B declines the call on one of the two connections, causing both connections
+    // to stop ringing.
+    active_call_b2.update(cx_b2, |call, _| call.decline_incoming().unwrap());
+    assert!(incoming_call_b1.next().await.unwrap().is_none());
+    assert!(incoming_call_b2.next().await.unwrap().is_none());
+
+    // Call user B again from client A.
+    active_call_a
+        .update(cx_a, |call, cx| {
+            call.invite(client_b1.user_id().unwrap(), None, cx)
+        })
+        .await
+        .unwrap();
+    assert!(incoming_call_b1.next().await.unwrap().is_some());
+    assert!(incoming_call_b2.next().await.unwrap().is_some());
+
+    // User B accepts the call on one of the two connections, causing both connections
+    // to stop ringing.
+    active_call_b2
+        .update(cx_b2, |call, cx| call.accept_incoming(cx))
+        .await
+        .unwrap();
+    assert!(incoming_call_b1.next().await.unwrap().is_none());
+    assert!(incoming_call_b2.next().await.unwrap().is_none());
+
+    // User B hangs up, and user A calls them again.
+    active_call_b2.update(cx_b2, |call, cx| call.hang_up(cx).unwrap());
+    deterministic.run_until_parked();
+    active_call_a
+        .update(cx_a, |call, cx| {
+            call.invite(client_b1.user_id().unwrap(), None, cx)
+        })
+        .await
+        .unwrap();
+    assert!(incoming_call_b1.next().await.unwrap().is_some());
+    assert!(incoming_call_b2.next().await.unwrap().is_some());
+
+    // User A cancels the call, causing both connections to stop ringing.
+    active_call_a
+        .update(cx_a, |call, cx| {
+            call.cancel_invite(client_b1.user_id().unwrap(), cx)
+        })
+        .await
+        .unwrap();
+    assert!(incoming_call_b1.next().await.unwrap().is_none());
+    assert!(incoming_call_b2.next().await.unwrap().is_none());
+}
+
 #[gpui::test(iterations = 10)]
 async fn test_share_project(
     deterministic: Arc<Deterministic>,

crates/collab/src/rpc.rs 🔗

@@ -150,6 +150,7 @@ impl Server {
             .add_request_handler(Server::join_room)
             .add_message_handler(Server::leave_room)
             .add_request_handler(Server::call)
+            .add_request_handler(Server::cancel_call)
             .add_message_handler(Server::decline_call)
             .add_request_handler(Server::update_participant_location)
             .add_request_handler(Server::share_project)
@@ -599,7 +600,7 @@ impl Server {
         let (room, recipient_connection_ids) = store.join_room(room_id, request.sender_id)?;
         for recipient_id in recipient_connection_ids {
             self.peer
-                .send(recipient_id, proto::CancelCall {})
+                .send(recipient_id, proto::CallCanceled {})
                 .trace_err();
         }
         response.send(proto::JoinRoomResponse {
@@ -715,6 +716,26 @@ impl Server {
         Err(anyhow!("failed to ring call recipient"))?
     }
 
+    async fn cancel_call(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::CancelCall>,
+        response: Response<proto::CancelCall>,
+    ) -> Result<()> {
+        let mut store = self.store().await;
+        let (room, recipient_connection_ids) = store.cancel_call(
+            UserId::from_proto(request.payload.recipient_user_id),
+            request.sender_id,
+        )?;
+        for recipient_id in recipient_connection_ids {
+            self.peer
+                .send(recipient_id, proto::CallCanceled {})
+                .trace_err();
+        }
+        self.room_updated(room);
+        response.send(proto::Ack {})?;
+        Ok(())
+    }
+
     async fn decline_call(
         self: Arc<Server>,
         message: TypedEnvelope<proto::DeclineCall>,
@@ -723,7 +744,7 @@ impl Server {
         let (room, recipient_connection_ids) = store.call_declined(message.sender_id)?;
         for recipient_id in recipient_connection_ids {
             self.peer
-                .send(recipient_id, proto::CancelCall {})
+                .send(recipient_id, proto::CallCanceled {})
                 .trace_err();
         }
         self.room_updated(room);

crates/collab/src/rpc/store.rs 🔗

@@ -585,6 +585,51 @@ impl Store {
         Ok(room)
     }
 
+    pub fn cancel_call(
+        &mut self,
+        recipient_user_id: UserId,
+        canceller_connection_id: ConnectionId,
+    ) -> Result<(&proto::Room, HashSet<ConnectionId>)> {
+        let canceller_user_id = self.user_id_for_connection(canceller_connection_id)?;
+        let canceller = self
+            .connected_users
+            .get(&canceller_user_id)
+            .ok_or_else(|| anyhow!("no such connection"))?;
+        let recipient = self
+            .connected_users
+            .get(&recipient_user_id)
+            .ok_or_else(|| anyhow!("no such connection"))?;
+        let canceller_active_call = canceller
+            .active_call
+            .as_ref()
+            .ok_or_else(|| anyhow!("no active call"))?;
+        let recipient_active_call = recipient
+            .active_call
+            .as_ref()
+            .ok_or_else(|| anyhow!("no active call for recipient"))?;
+
+        anyhow::ensure!(
+            canceller_active_call.room_id == recipient_active_call.room_id,
+            "users are on different calls"
+        );
+        anyhow::ensure!(
+            recipient_active_call.connection_id.is_none(),
+            "recipient has already answered"
+        );
+        let room_id = recipient_active_call.room_id;
+        let room = self
+            .rooms
+            .get_mut(&room_id)
+            .ok_or_else(|| anyhow!("no such room"))?;
+        room.pending_user_ids
+            .retain(|user_id| UserId::from_proto(*user_id) != recipient_user_id);
+
+        let recipient = self.connected_users.get_mut(&recipient_user_id).unwrap();
+        recipient.active_call.take();
+
+        Ok((room, recipient.connection_ids.clone()))
+    }
+
     pub fn call_declined(
         &mut self,
         recipient_connection_id: ConnectionId,

crates/rpc/proto/zed.proto 🔗

@@ -18,7 +18,8 @@ message Envelope {
         LeaveRoom leave_room = 1002;
         Call call = 12;
         IncomingCall incoming_call = 1000;
-        CancelCall cancel_call = 1001;
+        CallCanceled call_canceled = 1001;
+        CancelCall cancel_call = 1004;
         DeclineCall decline_call = 13;
         UpdateParticipantLocation update_participant_location = 1003;
         RoomUpdated room_updated = 14;
@@ -189,7 +190,11 @@ message IncomingCall {
     optional uint64 initial_project_id = 4;
 }
 
-message CancelCall {}
+message CallCanceled {}
+
+message CancelCall {
+    uint64 recipient_user_id = 1;
+}
 
 message DeclineCall {}
 

crates/rpc/src/proto.rs 🔗

@@ -84,6 +84,7 @@ messages!(
     (BufferReloaded, Foreground),
     (BufferSaved, Foreground),
     (Call, Foreground),
+    (CallCanceled, Foreground),
     (CancelCall, Foreground),
     (ChannelMessageSent, Foreground),
     (CopyProjectEntry, Foreground),
@@ -183,6 +184,7 @@ request_messages!(
         ApplyCompletionAdditionalEditsResponse
     ),
     (Call, Ack),
+    (CancelCall, Ack),
     (CopyProjectEntry, ProjectEntryResponse),
     (CreateProjectEntry, ProjectEntryResponse),
     (CreateRoom, CreateRoomResponse),