From 3db215418c609a2cbe7cdde96ff25967aba835f3 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 15 Jun 2021 14:12:42 -0700 Subject: [PATCH] Allow RpcClient to encapsulate arbitrarily many connections --- zed-rpc/src/proto.rs | 2 +- zed/src/rpc_client.rs | 170 ++++++++++++++++++++++++------------------ zed/src/workspace.rs | 18 +++-- zed/src/worktree.rs | 12 ++- 4 files changed, 119 insertions(+), 83 deletions(-) diff --git a/zed-rpc/src/proto.rs b/zed-rpc/src/proto.rs index ac17390d4809e381f367dd71733a942123f6490b..86200c08422a698bf33a658c6c4247acfe71bd15 100644 --- a/zed-rpc/src/proto.rs +++ b/zed-rpc/src/proto.rs @@ -5,7 +5,7 @@ use std::{convert::TryInto, io}; include!(concat!(env!("OUT_DIR"), "/zed.messages.rs")); -pub trait EnvelopedMessage: Sized { +pub trait EnvelopedMessage: Sized + Send + 'static { fn into_envelope(self, id: u32, responding_to: Option) -> Envelope; fn from_envelope(envelope: Envelope) -> Option; } diff --git a/zed/src/rpc_client.rs b/zed/src/rpc_client.rs index 960536a503d4867c5a97a7b5082427e85befe99c..12ea77dae76ac73422d4ed0f806dcc6ede992563 100644 --- a/zed/src/rpc_client.rs +++ b/zed/src/rpc_client.rs @@ -6,7 +6,7 @@ use postage::{ prelude::{Sink, Stream}, }; use smol::{ - io::{BoxedWriter, ReadHalf}, + io::BoxedWriter, lock::Mutex, prelude::{AsyncRead, AsyncWrite}, }; @@ -20,89 +20,103 @@ use std::{ }; use zed_rpc::proto::{self, EnvelopedMessage, MessageStream, RequestMessage}; +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub struct ConnectionId(u32); + pub struct RpcClient { response_channels: Arc>>>, - outgoing: Mutex>, + outgoing: Arc>>>, next_message_id: AtomicU32, + next_connection_id: AtomicU32, _drop_tx: barrier::Sender, + drop_rx: barrier::Receiver, } impl RpcClient { - pub fn new(conn: Conn, executor: Arc) -> Arc - where - Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static, - { - let response_channels = Arc::new(Mutex::new(HashMap::new())); - let (conn_rx, conn_tx) = smol::io::split(conn); + pub fn new() -> Arc { let (_drop_tx, drop_rx) = barrier::channel(); - - executor - .spawn(Self::handle_incoming( - conn_rx, - drop_rx, - response_channels.clone(), - )) - .detach(); - Arc::new(Self { - response_channels, - outgoing: Mutex::new(MessageStream::new(Box::pin(conn_tx))), + response_channels: Default::default(), + outgoing: Default::default(), + next_message_id: Default::default(), + next_connection_id: Default::default(), _drop_tx, - next_message_id: AtomicU32::new(0), + drop_rx, }) } - async fn handle_incoming( - conn: ReadHalf, - mut drop_rx: barrier::Receiver, - response_channels: Arc>>>, - ) where - Conn: AsyncRead + Unpin, + pub async fn connect(&self, conn: Conn, executor: Arc) -> ConnectionId + where + Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let dropped = drop_rx.recv(); - smol::pin!(dropped); + let connection_id = ConnectionId( + self.next_connection_id + .fetch_add(1, atomic::Ordering::SeqCst), + ); + let (conn_rx, conn_tx) = smol::io::split(conn); + let response_channels = self.response_channels.clone(); + let mut drop_rx = self.drop_rx.clone(); + let outgoing = self.outgoing.clone(); - let mut stream = MessageStream::new(conn); - loop { - let read_message = stream.read_message(); - smol::pin!(read_message); + executor + .spawn(async move { + let dropped = drop_rx.recv(); + smol::pin!(dropped); + + let mut stream = MessageStream::new(conn_rx); + loop { + let read_message = stream.read_message(); + smol::pin!(read_message); - match futures::future::select(read_message, &mut dropped).await { - Either::Left((Ok(incoming), _)) => { - if let Some(responding_to) = incoming.responding_to { - let channel = response_channels.lock().await.remove(&responding_to); - if let Some(mut tx) = channel { - tx.send(incoming).await.ok(); - } else { - log::warn!( - "received RPC response to unknown request {}", - responding_to - ); + match futures::future::select(read_message, &mut dropped).await { + Either::Left((Ok(incoming), _)) => { + if let Some(responding_to) = incoming.responding_to { + let channel = response_channels.lock().await.remove(&responding_to); + if let Some(mut tx) = channel { + tx.send(incoming).await.ok(); + } else { + log::warn!( + "received RPC response to unknown request {}", + responding_to + ); + } + } else { + // unprompted message from server + } } - } else { - // unprompted message from server + Either::Left((Err(error), _)) => { + log::warn!("received invalid RPC message {:?}", error); + } + Either::Right(_) => break, } } - Either::Left((Err(error), _)) => { - log::warn!("received invalid RPC message {:?}", error); - } - Either::Right(_) => break, - } - } + }) + .detach(); + + outgoing + .lock() + .await + .insert(connection_id, MessageStream::new(Box::pin(conn_tx))); + + connection_id } pub fn request( - self: &Arc, + &self, + connection_id: ConnectionId, req: T, ) -> impl Future> { - let this = self.clone(); + let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst); + let outgoing = self.outgoing.clone(); + let response_channels = self.response_channels.clone(); + let (tx, mut rx) = oneshot::channel(); async move { - let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst); - let (tx, mut rx) = oneshot::channel(); - this.response_channels.lock().await.insert(message_id, tx); - this.outgoing + response_channels.lock().await.insert(message_id, tx); + outgoing .lock() .await + .get_mut(&connection_id) + .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))? .write_message(&req.into_envelope(message_id, None)) .await?; let response = rx @@ -115,15 +129,18 @@ impl RpcClient { } pub fn send( - self: &Arc, + &self, + connection_id: ConnectionId, message: T, ) -> impl Future> { - let this = self.clone(); + let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst); + let outgoing = self.outgoing.clone(); async move { - let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst); - this.outgoing + outgoing .lock() .await + .get_mut(&connection_id) + .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))? .write_message(&message.into_envelope(message_id, None)) .await?; Ok(()) @@ -152,12 +169,16 @@ mod tests { let (server_conn, _) = listener.accept().await.unwrap(); let mut server_stream = MessageStream::new(server_conn); - let client = RpcClient::new(client_conn, executor.clone()); + let client = RpcClient::new(); + let connection_id = client.connect(client_conn, executor.clone()).await; - let client_req = client.request(proto::Auth { - user_id: 42, - access_token: "token".to_string(), - }); + let client_req = client.request( + connection_id, + proto::Auth { + user_id: 42, + access_token: "token".to_string(), + }, + ); smol::pin!(client_req); let server_req = send_recv(&mut client_req, server_stream.read_message()) .await @@ -206,7 +227,8 @@ mod tests { let client_conn = UnixStream::connect(&socket_path).await.unwrap(); let (mut server_conn, _) = listener.accept().await.unwrap(); - let client = RpcClient::new(client_conn, executor.clone()); + let client = RpcClient::new(); + client.connect(client_conn, executor.clone()).await; drop(client); // Try sending an empty payload over and over, until the client is dropped and hangs up. @@ -231,12 +253,16 @@ mod tests { let mut client_conn = UnixStream::connect(&socket_path).await.unwrap(); client_conn.close().await.unwrap(); - let client = RpcClient::new(client_conn, executor.clone()); + let client = RpcClient::new(); + let conn_id = client.connect(client_conn, executor.clone()).await; let err = client - .request(proto::Auth { - user_id: 42, - access_token: "token".to_string(), - }) + .request( + conn_id, + proto::Auth { + user_id: 42, + access_token: "token".to_string(), + }, + ) .await .unwrap_err(); assert_eq!( diff --git a/zed/src/workspace.rs b/zed/src/workspace.rs index 6c87adff160d593f1f8c67ef371060d57a5f140f..ef75aa702a9e26e070dae605b6791409cb6bda89 100644 --- a/zed/src/workspace.rs +++ b/zed/src/workspace.rs @@ -670,13 +670,17 @@ impl Workspace { // a TLS stream using `native-tls`. let stream = smol::net::TcpStream::connect(rpc_address).await?; - let rpc_client = RpcClient::new(stream, executor); + let rpc_client = RpcClient::new(); + let connection_id = rpc_client.connect(stream, executor).await; let auth_response = rpc_client - .request(proto::Auth { - user_id: user_id.parse::()?, - access_token, - }) + .request( + connection_id, + proto::Auth { + user_id: user_id.parse::()?, + access_token, + }, + ) .await?; if !auth_response.credentials_valid { Err(anyhow!("failed to authenticate with RPC server"))?; @@ -684,7 +688,9 @@ impl Workspace { let share_task = this.update(&mut cx, |this, cx| { let worktree = this.worktrees.iter().next()?; - Some(worktree.update(cx, |worktree, cx| worktree.share(rpc_client, cx))) + Some(worktree.update(cx, |worktree, cx| { + worktree.share(rpc_client, connection_id, cx) + })) }); if let Some(share_task) = share_task { diff --git a/zed/src/worktree.rs b/zed/src/worktree.rs index 61b0e81276f61508d49a739c50c7fe8deb9d8cd8..40856a282f5bfdba5a7bb568df5311d4249f1aa8 100644 --- a/zed/src/worktree.rs +++ b/zed/src/worktree.rs @@ -4,7 +4,7 @@ mod ignore; use crate::{ editor::{History, Rope}, - rpc_client::RpcClient, + rpc_client::{ConnectionId, RpcClient}, sum_tree::{self, Cursor, Edit, SumTree}, util::Bias, }; @@ -229,6 +229,7 @@ impl Worktree { pub fn share( &mut self, client: Arc, + connection_id: ConnectionId, cx: &mut ModelContext, ) -> Task> { self.rpc_client = Some(client.clone()); @@ -245,9 +246,12 @@ impl Worktree { .await; let share_response = client - .request(proto::ShareWorktree { - worktree: Some(proto::Worktree { paths }), - }) + .request( + connection_id, + proto::ShareWorktree { + worktree: Some(proto::Worktree { paths }), + }, + ) .await?; log::info!("sharing worktree {:?}", share_response);