Use `PeerId` in `TestServer::disconnect_client`

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/client/src/client.rs            |  8 +++
crates/collab/src/integration_tests.rs | 59 +++++++++++++++------------
crates/collab/src/rpc.rs               |  8 +-
3 files changed, 45 insertions(+), 30 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -319,6 +319,14 @@ impl Client {
             .map(|credentials| credentials.user_id)
     }
 
+    pub fn peer_id(&self) -> Option<PeerId> {
+        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
+            Some(*peer_id)
+        } else {
+            None
+        }
+    }
+
     pub fn status(&self) -> watch::Receiver<Status> {
         self.state.read().status.1.clone()
     }

crates/collab/src/integration_tests.rs 🔗

@@ -8,7 +8,7 @@ use anyhow::anyhow;
 use call::{room, ActiveCall, ParticipantLocation, Room};
 use client::{
     self, test::FakeHttpClient, Channel, ChannelDetails, ChannelList, Client, Connection,
-    Credentials, EstablishConnectionError, User, UserStore, RECEIVE_TIMEOUT,
+    Credentials, EstablishConnectionError, PeerId, User, UserStore, RECEIVE_TIMEOUT,
 };
 use collections::{BTreeMap, HashMap, HashSet};
 use editor::{
@@ -16,7 +16,10 @@ use editor::{
     ToggleCodeActions, Undo,
 };
 use fs::{FakeFs, Fs as _, HomeDir, LineEnding};
-use futures::{channel::mpsc, Future, StreamExt as _};
+use futures::{
+    channel::{mpsc, oneshot},
+    Future, StreamExt as _,
+};
 use gpui::{
     executor::{self, Deterministic},
     geometry::vector::vec2f,
@@ -34,7 +37,6 @@ use project::{
     ProjectStore, WorktreeId,
 };
 use rand::prelude::*;
-use rpc::PeerId;
 use serde_json::json;
 use settings::{Formatter, Settings};
 use sqlx::types::time::OffsetDateTime;
@@ -385,7 +387,7 @@ async fn test_leaving_room_on_disconnection(
     );
 
     // When user A disconnects, both client A and B clear their room on the active call.
-    server.disconnect_client(client_a.current_user_id(cx_a));
+    server.disconnect_client(client_a.peer_id().unwrap());
     cx_a.foreground().advance_clock(rpc::RECEIVE_TIMEOUT);
     active_call_a.read_with(cx_a, |call, _| assert!(call.room().is_none()));
     active_call_b.read_with(cx_b, |call, _| assert!(call.room().is_none()));
@@ -529,7 +531,7 @@ async fn test_calls_on_multiple_connections(
     assert!(incoming_call_b2.next().await.unwrap().is_some());
 
     // User A disconnects, causing both connections to stop ringing.
-    server.disconnect_client(client_a.current_user_id(cx_a));
+    server.disconnect_client(client_a.peer_id().unwrap());
     deterministic.advance_clock(rpc::RECEIVE_TIMEOUT);
     assert!(incoming_call_b1.next().await.unwrap().is_none());
     assert!(incoming_call_b2.next().await.unwrap().is_none());
@@ -547,7 +549,7 @@ async fn test_calls_on_multiple_connections(
 
     // User B disconnects all clients, causing user A to no longer see a pending call for them.
     println!("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!");
-    server.disconnect_client(client_b1.current_user_id(cx_b1));
+    server.disconnect_client(client_b1.peer_id().unwrap());
     deterministic.advance_clock(rpc::RECEIVE_TIMEOUT);
     active_call_a.read_with(cx_a, |call, _| assert!(call.room().is_none()));
 }
@@ -607,7 +609,7 @@ async fn test_share_project(
         .update(cx_b, |call, cx| call.accept_incoming(cx))
         .await
         .unwrap();
-    let client_b_peer_id = client_b.peer_id;
+    let client_b_peer_id = client_b.peer_id().unwrap();
     let project_b = client_b
         .build_remote_project(initial_project.id, cx_b)
         .await;
@@ -831,7 +833,7 @@ async fn test_host_disconnect(
     assert!(cx_b.is_window_edited(workspace_b.window_id()));
 
     // Drop client A's connection. Collaborators should disappear and the project should not be shown as shared.
-    server.disconnect_client(client_a.current_user_id(cx_a));
+    server.disconnect_client(client_a.peer_id().unwrap());
     deterministic.advance_clock(rpc::RECEIVE_TIMEOUT);
     project_a
         .condition(cx_a, |project, _| project.collaborators().is_empty())
@@ -874,7 +876,7 @@ async fn test_host_disconnect(
         .unwrap();
 
     // Drop client A's connection again. We should still unshare it successfully.
-    server.disconnect_client(client_a.current_user_id(cx_a));
+    server.disconnect_client(client_a.peer_id().unwrap());
     deterministic.advance_clock(rpc::RECEIVE_TIMEOUT);
     project_a.read_with(cx_a, |project, _| assert!(!project.is_shared()));
 }
@@ -2175,7 +2177,7 @@ async fn test_leaving_project(
 
     // Simulate connection loss for client C and ensure client A observes client C leaving the project.
     client_c.wait_for_current_user(cx_c).await;
-    server.disconnect_client(client_c.current_user_id(cx_c));
+    server.disconnect_client(client_c.peer_id().unwrap());
     cx_a.foreground().advance_clock(rpc::RECEIVE_TIMEOUT);
     deterministic.run_until_parked();
     project_a.read_with(cx_a, |project, _| {
@@ -4338,7 +4340,7 @@ async fn test_chat_reconnection(cx_a: &mut TestAppContext, cx_b: &mut TestAppCon
 
     // Disconnect client B, ensuring we can still access its cached channel data.
     server.forbid_connections();
-    server.disconnect_client(client_b.current_user_id(cx_b));
+    server.disconnect_client(client_b.peer_id().unwrap());
     cx_b.foreground().advance_clock(rpc::RECEIVE_TIMEOUT);
     while !matches!(
         status_b.next().await,
@@ -4501,7 +4503,7 @@ async fn test_contacts(
         ]
     );
 
-    server.disconnect_client(client_c.current_user_id(cx_c));
+    server.disconnect_client(client_c.peer_id().unwrap());
     server.forbid_connections();
     deterministic.advance_clock(rpc::RECEIVE_TIMEOUT);
     assert_eq!(
@@ -4741,7 +4743,7 @@ async fn test_contacts(
     );
 
     server.forbid_connections();
-    server.disconnect_client(client_a.current_user_id(cx_a));
+    server.disconnect_client(client_a.peer_id().unwrap());
     deterministic.advance_clock(rpc::RECEIVE_TIMEOUT);
     assert_eq!(contacts(&client_a, cx_a), []);
     assert_eq!(
@@ -5651,6 +5653,7 @@ async fn test_random_collaboration(
 
     let mut clients = Vec::new();
     let mut user_ids = Vec::new();
+    let mut peer_ids = Vec::new();
     let mut op_start_signals = Vec::new();
 
     let mut next_entity_id = 100000;
@@ -5839,6 +5842,7 @@ async fn test_random_collaboration(
 
     let op_start_signal = futures::channel::mpsc::unbounded();
     user_ids.push(host_user_id);
+    peer_ids.push(host.peer_id().unwrap());
     op_start_signals.push(op_start_signal.0);
     clients.push(host_cx.foreground().spawn(host.simulate_host(
         host_project,
@@ -5856,7 +5860,7 @@ async fn test_random_collaboration(
     let mut operations = 0;
     while operations < max_operations {
         if operations == disconnect_host_at {
-            server.disconnect_client(user_ids[0]);
+            server.disconnect_client(peer_ids[0]);
             deterministic.advance_clock(RECEIVE_TIMEOUT);
             drop(op_start_signals);
 
@@ -5939,6 +5943,7 @@ async fn test_random_collaboration(
 
                 let op_start_signal = futures::channel::mpsc::unbounded();
                 user_ids.push(guest_user_id);
+                peer_ids.push(guest.peer_id().unwrap());
                 op_start_signals.push(op_start_signal.0);
                 clients.push(guest_cx.foreground().spawn(guest.simulate_guest(
                     guest_username.clone(),
@@ -5955,10 +5960,11 @@ async fn test_random_collaboration(
                 let guest_ix = rng.lock().gen_range(1..clients.len());
                 log::info!("Removing guest {}", user_ids[guest_ix]);
                 let removed_guest_id = user_ids.remove(guest_ix);
+                let removed_peer_id = peer_ids.remove(guest_ix);
                 let guest = clients.remove(guest_ix);
                 op_start_signals.remove(guest_ix);
                 server.forbid_connections();
-                server.disconnect_client(removed_guest_id);
+                server.disconnect_client(removed_peer_id);
                 deterministic.advance_clock(RECEIVE_TIMEOUT);
                 deterministic.start_waiting();
                 log::info!("Waiting for guest {} to exit...", removed_guest_id);
@@ -6082,8 +6088,10 @@ async fn test_random_collaboration(
             let host_buffer = host_project.read_with(&host_cx, |project, cx| {
                 project.buffer_for_id(buffer_id, cx).unwrap_or_else(|| {
                     panic!(
-                        "host does not have buffer for guest:{}, peer:{}, id:{}",
-                        guest_client.username, guest_client.peer_id, buffer_id
+                        "host does not have buffer for guest:{}, peer:{:?}, id:{}",
+                        guest_client.username,
+                        guest_client.peer_id(),
+                        buffer_id
                     )
                 })
             });
@@ -6126,7 +6134,7 @@ struct TestServer {
     server: Arc<Server>,
     foreground: Rc<executor::Foreground>,
     notifications: mpsc::UnboundedReceiver<()>,
-    connection_killers: Arc<Mutex<HashMap<UserId, Arc<AtomicBool>>>>,
+    connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
     forbid_connections: Arc<AtomicBool>,
     _test_db: TestDb,
 }
@@ -6192,7 +6200,6 @@ impl TestServer {
         let db = self.app_state.db.clone();
         let connection_killers = self.connection_killers.clone();
         let forbid_connections = self.forbid_connections.clone();
-        let (connection_id_tx, mut connection_id_rx) = mpsc::channel(16);
 
         Arc::get_mut(&mut client)
             .unwrap()
@@ -6215,7 +6222,6 @@ impl TestServer {
                 let connection_killers = connection_killers.clone();
                 let forbid_connections = forbid_connections.clone();
                 let client_name = client_name.clone();
-                let connection_id_tx = connection_id_tx.clone();
                 cx.spawn(move |cx| async move {
                     if forbid_connections.load(SeqCst) {
                         Err(EstablishConnectionError::other(anyhow!(
@@ -6224,7 +6230,7 @@ impl TestServer {
                     } else {
                         let (client_conn, server_conn, killed) =
                             Connection::in_memory(cx.background());
-                        connection_killers.lock().insert(user_id, killed);
+                        let (connection_id_tx, connection_id_rx) = oneshot::channel();
                         let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
                         cx.background()
                             .spawn(server.handle_connection(
@@ -6235,6 +6241,10 @@ impl TestServer {
                                 cx.background(),
                             ))
                             .detach();
+                        let connection_id = connection_id_rx.await.unwrap();
+                        connection_killers
+                            .lock()
+                            .insert(PeerId(connection_id.0), killed);
                         Ok(client_conn)
                     }
                 })
@@ -6266,11 +6276,9 @@ impl TestServer {
             .authenticate_and_connect(false, &cx.to_async())
             .await
             .unwrap();
-        let peer_id = PeerId(connection_id_rx.next().await.unwrap().0);
 
         let client = TestClient {
             client,
-            peer_id,
             username: name.to_string(),
             user_store,
             project_store,
@@ -6282,10 +6290,10 @@ impl TestServer {
         client
     }
 
-    fn disconnect_client(&self, user_id: UserId) {
+    fn disconnect_client(&self, peer_id: PeerId) {
         self.connection_killers
             .lock()
-            .remove(&user_id)
+            .remove(&peer_id)
             .unwrap()
             .store(true, SeqCst);
     }
@@ -6386,7 +6394,6 @@ impl Drop for TestServer {
 struct TestClient {
     client: Arc<Client>,
     username: String,
-    pub peer_id: PeerId,
     pub user_store: ModelHandle<UserStore>,
     pub project_store: ModelHandle<ProjectStore>,
     language_registry: Arc<LanguageRegistry>,

crates/collab/src/rpc.rs 🔗

@@ -24,7 +24,7 @@ use axum::{
 };
 use collections::{HashMap, HashSet};
 use futures::{
-    channel::mpsc,
+    channel::{mpsc, oneshot},
     future::{self, BoxFuture},
     stream::FuturesUnordered,
     FutureExt, SinkExt, StreamExt, TryStreamExt,
@@ -348,7 +348,7 @@ impl Server {
         connection: Connection,
         address: String,
         user: User,
-        mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
+        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
         executor: E,
     ) -> impl Future<Output = Result<()>> {
         let mut this = self.clone();
@@ -372,8 +372,8 @@ impl Server {
             this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
             tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 
-            if let Some(send_connection_id) = send_connection_id.as_mut() {
-                let _ = send_connection_id.send(connection_id).await;
+            if let Some(send_connection_id) = send_connection_id.take() {
+                let _ = send_connection_id.send(connection_id);
             }
 
             if !user.connected_once {