From 4aab70d7fb81aa091528c8412aa246a52be8bc7b Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Fri, 9 Jul 2021 21:57:10 -0600 Subject: [PATCH] Remove authentication at the RPC layer This means we can remove IOHandler and return a simple future that is Send --- zed-rpc/src/peer.rs | 115 ++++++++++++++++---------------------------- zed/src/rpc.rs | 28 ++--------- 2 files changed, 45 insertions(+), 98 deletions(-) diff --git a/zed-rpc/src/peer.rs b/zed-rpc/src/peer.rs index d825340db0e374ee971ac503beab4a480f41a158..daf83e2f1ba7bee5a125d4d907e7a88c268a2203 100644 --- a/zed-rpc/src/peer.rs +++ b/zed-rpc/src/peer.rs @@ -4,7 +4,6 @@ use async_lock::{Mutex, RwLock}; use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; use futures::{ future::{BoxFuture, LocalBoxFuture}, - stream::{SplitSink, SplitStream}, FutureExt, StreamExt, }; use postage::{ @@ -87,14 +86,6 @@ struct Connection { response_channels: Arc>>>, } -pub struct IOHandler { - connection_id: ConnectionId, - incoming_tx: mpsc::Sender, - outgoing_rx: mpsc::Receiver, - writer: MessageStream, - reader: MessageStream, -} - impl Peer { pub fn new() -> Arc { Arc::new(Self { @@ -109,7 +100,7 @@ impl Peer { router: Arc>, ) -> ( ConnectionId, - IOHandler, SplitStream>, + impl Future> + Send, impl Future>, ) where @@ -117,6 +108,7 @@ impl Peer { Fut: Future, Conn: futures::Sink + futures::Stream> + + Send + Unpin, { let (tx, rx) = conn.split(); @@ -124,19 +116,44 @@ impl Peer { self.next_connection_id .fetch_add(1, atomic::Ordering::SeqCst), ); - let (incoming_tx, mut incoming_rx) = mpsc::channel(64); - let (outgoing_tx, outgoing_rx) = mpsc::channel(64); + let (mut incoming_tx, mut incoming_rx) = mpsc::channel(64); + let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64); let connection = Connection { outgoing_tx, next_message_id: Default::default(), response_channels: Default::default(), }; - let handle_io = IOHandler { - connection_id, - outgoing_rx, - incoming_tx, - writer: MessageStream::new(tx), - reader: MessageStream::new(rx), + let mut writer = MessageStream::new(tx); + let mut reader = MessageStream::new(rx); + + let handle_io = async move { + loop { + let read_message = reader.read_message().fuse(); + futures::pin_mut!(read_message); + loop { + futures::select_biased! { + incoming = read_message => match incoming { + Ok(incoming) => { + if incoming_tx.send(incoming).await.is_err() { + return Ok(()); + } + break; + } + Err(error) => { + Err(error).context("received invalid RPC message")?; + } + }, + outgoing = outgoing_rx.recv().fuse() => match outgoing { + Some(outgoing) => { + if let Err(result) = writer.write_message(&outgoing).await { + Err(result).context("failed to write RPC message")?; + } + } + None => return Ok(()), + } + } + } + } }; let response_channels = connection.response_channels.clone(); @@ -402,56 +419,6 @@ impl ForegroundRouter { } } -impl IOHandler -where - W: futures::Sink + Unpin, - R: futures::Stream> + Unpin, -{ - pub async fn run(mut self) -> Result<()> { - loop { - let read_message = self.reader.read_message().fuse(); - futures::pin_mut!(read_message); - loop { - futures::select_biased! { - incoming = read_message => match incoming { - Ok(incoming) => { - if self.incoming_tx.send(incoming).await.is_err() { - return Ok(()); - } - break; - } - Err(error) => { - Err(error).context("received invalid RPC message")?; - } - }, - outgoing = self.outgoing_rx.recv().fuse() => match outgoing { - Some(outgoing) => { - if let Err(result) = self.writer.write_message(&outgoing).await { - Err(result).context("failed to write RPC message")?; - } - } - None => return Ok(()), - } - } - } - } - } - - pub async fn receive(&mut self) -> Result> { - let envelope = self.reader.read_message().await?; - let original_sender_id = envelope.original_sender_id; - let message_id = envelope.id; - let payload = - M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?; - Ok(TypedEnvelope { - sender_id: self.connection_id, - original_sender_id: original_sender_id.map(PeerId), - message_id, - payload, - }) - } -} - impl Clone for Receipt { fn clone(&self) -> Self { Self { @@ -583,10 +550,10 @@ mod tests { .add_connection(server_to_client_2_conn, router.clone()) .await; - smol::spawn(io_task1.run()).detach(); - smol::spawn(io_task2.run()).detach(); - smol::spawn(io_task3.run()).detach(); - smol::spawn(io_task4.run()).detach(); + smol::spawn(io_task1).detach(); + smol::spawn(io_task2).detach(); + smol::spawn(io_task3).detach(); + smol::spawn(io_task4).detach(); smol::spawn(msg_task1).detach(); smol::spawn(msg_task2).detach(); smol::spawn(msg_task3).detach(); @@ -683,7 +650,7 @@ mod tests { let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel(); smol::spawn(async move { - io_handler.run().await.ok(); + io_handler.await.ok(); io_ended_tx.send(()).await.unwrap(); }) .detach(); @@ -717,7 +684,7 @@ mod tests { let router = Arc::new(Router::new()); let (connection_id, io_handler, message_handler) = client.add_connection(client_conn, router).await; - smol::spawn(io_handler.run()).detach(); + smol::spawn(io_handler).detach(); smol::spawn(message_handler).detach(); let err = client diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 52178b4f0e93e80550c1a3a57be2b86339b5957e..286a6a6c158b6cf73ad2b34105cb87c16a43ec3d 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -91,7 +91,7 @@ impl Client { } let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?; - let user_id = user_id.parse()?; + let user_id: i32 = user_id.parse()?; let request = Request::builder().header("Authorization", format!("{} {}", user_id, access_token)); @@ -102,15 +102,13 @@ impl Client { .await .context("websocket handshake")?; log::info!("connected to rpc address {}", *ZED_SERVER_URL); - self.add_connection(stream, user_id, access_token, router, cx) - .await?; + self.add_connection(stream, router, cx).await?; } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") { let stream = smol::net::TcpStream::connect(host).await?; let request = request.uri(format!("ws://{}/rpc", host)).body(())?; let (stream, _) = async_tungstenite::client_async(request, stream).await?; log::info!("connected to rpc address {}", *ZED_SERVER_URL); - self.add_connection(stream, user_id, access_token, router, cx) - .await?; + self.add_connection(stream, router, cx).await?; } else { return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?; }; @@ -121,8 +119,6 @@ impl Client { pub async fn add_connection( &self, conn: Conn, - user_id: i32, - access_token: String, router: Arc, cx: AsyncAppContext, ) -> surf::Result<()> @@ -138,27 +134,11 @@ impl Client { cx.foreground().spawn(handle_messages).detach(); cx.background() .spawn(async move { - if let Err(error) = handle_io.run().await { + if let Err(error) = handle_io.await { log::error!("connection error: {:?}", error); } }) .detach(); - - let auth_response = self - .peer - .request( - connection_id, - proto::Auth { - user_id, - access_token, - }, - ) - .await - .context("rpc auth request failed")?; - if !auth_response.credentials_valid { - Err(anyhow!("failed to authenticate with RPC server"))?; - } - self.state.write().await.connection_id = Some(connection_id); Ok(()) }