From 05a662b35edfafb68ed57fcc2c27b7442291604d Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 17 Jun 2021 14:19:15 -0700 Subject: [PATCH] Allow peers to receive individual messages before starting message loop Co-Authored-By: Nathan Sobo --- zed-rpc/src/peer.rs | 118 ++++++++++++++++++++++++++----------------- zed/src/workspace.rs | 4 +- 2 files changed, 75 insertions(+), 47 deletions(-) diff --git a/zed-rpc/src/peer.rs b/zed-rpc/src/peer.rs index 333e4b1672babcf281ed9e30e52704dd23803733..3307a7f60a6a72fbbf26893230b48677afc92f77 100644 --- a/zed-rpc/src/peer.rs +++ b/zed-rpc/src/peer.rs @@ -21,12 +21,14 @@ use std::{ }; type BoxedWriter = Pin>; +type BoxedReader = Pin>; #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub struct ConnectionId(u32); struct Connection { writer: Mutex>, + reader: Mutex>, response_channels: Mutex>>, next_message_id: AtomicU32, } @@ -52,7 +54,8 @@ impl TypedEnvelope { } pub struct Peer { - connections: RwLock, barrier::Sender)>>, + connections: RwLock>>, + connection_close_barriers: RwLock>, message_handlers: RwLock>, handler_types: Mutex>, next_connection_id: AtomicU32, @@ -62,6 +65,7 @@ impl Peer { pub fn new() -> Arc { Arc::new(Self { connections: Default::default(), + connection_close_barriers: Default::default(), message_handlers: Default::default(), handler_types: Default::default(), next_connection_id: Default::default(), @@ -102,10 +106,7 @@ impl Peer { rx } - pub async fn add_connection( - self: &Arc, - conn: Conn, - ) -> (ConnectionId, impl Future>) + pub async fn add_connection(self: &Arc, conn: Conn) -> ConnectionId where Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -113,26 +114,44 @@ impl Peer { self.next_connection_id .fetch_add(1, atomic::Ordering::SeqCst), ); - let (close_tx, mut close_rx) = barrier::channel(); - let connection = Arc::new(Connection { - writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))), - response_channels: Default::default(), - next_message_id: Default::default(), - }); + self.connections.write().await.insert( + connection_id, + Arc::new(Connection { + reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))), + writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))), + response_channels: Default::default(), + next_message_id: Default::default(), + }), + ); + connection_id + } - self.connections + pub async fn disconnect(&self, connection_id: ConnectionId) { + self.connections.write().await.remove(&connection_id); + self.connection_close_barriers .write() .await - .insert(connection_id, (connection.clone(), close_tx)); + .remove(&connection_id); + } + pub fn handle_messages( + self: &Arc, + connection_id: ConnectionId, + ) -> impl Future> + 'static { + let (close_tx, mut close_rx) = barrier::channel(); let this = self.clone(); - let handler_future = async move { + async move { + this.connection_close_barriers + .write() + .await + .insert(connection_id, close_tx); + let connection = this.connection(connection_id).await?; let closed = close_rx.recv(); futures::pin_mut!(closed); - let mut stream = MessageStream::new(conn); loop { - let read_message = stream.read_message(); + let mut reader = connection.reader.lock().await; + let read_message = reader.read_message(); futures::pin_mut!(read_message); match futures::future::select(read_message, &mut closed).await { @@ -181,13 +200,23 @@ impl Peer { Either::Right(_) => return Ok(()), } } - }; - - (connection_id, handler_future) + } } - pub async fn disconnect(&self, connection_id: ConnectionId) { - self.connections.write().await.remove(&connection_id); + pub async fn receive( + self: &Arc, + connection_id: ConnectionId, + ) -> Result> { + let connection = self.connection(connection_id).await?; + let envelope = connection.reader.lock().await.read_message().await?; + let id = envelope.id; + let payload = + M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?; + Ok(TypedEnvelope { + id, + connection_id, + payload, + }) } pub fn request( @@ -271,7 +300,6 @@ impl Peer { .await .get(&id) .ok_or_else(|| anyhow!("unknown connection: {}", id.0))? - .0 .clone()) } } @@ -298,22 +326,22 @@ mod tests { let server = Peer::new(); let client1 = Peer::new(); let client2 = Peer::new(); - let (client1_conn_id, f1) = client1 + let client1_conn_id = client1 .add_connection(UnixStream::connect(&socket_path).await.unwrap()) .await; - let (client2_conn_id, f2) = client2 + let client2_conn_id = client2 .add_connection(UnixStream::connect(&socket_path).await.unwrap()) .await; - let (_, f3) = server + let server_conn_id1 = server .add_connection(listener.accept().await.unwrap().0) .await; - let (_, f4) = server + let server_conn_id2 = server .add_connection(listener.accept().await.unwrap().0) .await; - smol::spawn(f1).detach(); - smol::spawn(f2).detach(); - smol::spawn(f3).detach(); - smol::spawn(f4).detach(); + smol::spawn(client1.handle_messages(client1_conn_id)).detach(); + smol::spawn(client2.handle_messages(client2_conn_id)).detach(); + smol::spawn(server.handle_messages(server_conn_id1)).detach(); + smol::spawn(server.handle_messages(server_conn_id2)).detach(); // define the expected requests and responses let request1 = proto::OpenWorktree { @@ -428,21 +456,21 @@ mod tests { let (mut server_conn, _) = listener.accept().await.unwrap(); let client = Peer::new(); - let (connection_id, handler) = client.add_connection(client_conn).await; - smol::spawn(handler).detach(); + let connection_id = client.add_connection(client_conn).await; + let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) = + barrier::channel(); + let handle_messages = client.handle_messages(connection_id); + smol::spawn(async move { + handle_messages.await.unwrap(); + incoming_messages_ended_tx.send(()).await.unwrap(); + }) + .detach(); client.disconnect(connection_id).await; - // Try sending an empty payload over and over, until the client is dropped and hangs up. - loop { - match server_conn.write(&[]).await { - Ok(_) => {} - Err(err) => { - if err.kind() == io::ErrorKind::BrokenPipe { - break; - } - } - } - } + incoming_messages_ended_rx.recv().await; + + let err = server_conn.write(&[]).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); }); } @@ -456,8 +484,8 @@ mod tests { client_conn.close().await.unwrap(); let client = Peer::new(); - let (connection_id, handler) = client.add_connection(client_conn).await; - smol::spawn(handler).detach(); + let connection_id = client.add_connection(client_conn).await; + smol::spawn(client.handle_messages(connection_id)).detach(); let err = client .request( diff --git a/zed/src/workspace.rs b/zed/src/workspace.rs index 8625b2fa43ce73e661c1eba61ef9fffdee4271b5..ac65daec85ff497d35a7eae7b0396f4c0f84a2bd 100644 --- a/zed/src/workspace.rs +++ b/zed/src/workspace.rs @@ -691,8 +691,8 @@ impl Workspace { // a TLS stream using `native-tls`. let stream = smol::net::TcpStream::connect(rpc_address).await?; - let (connection_id, handler) = rpc.add_connection(stream).await; - executor.spawn(handler).detach(); + let connection_id = rpc.add_connection(stream).await; + executor.spawn(rpc.handle_messages(connection_id)).detach(); let auth_response = rpc .request(