Remove authentication at the RPC layer

Nathan Sobo created

This means we can remove IOHandler and return a simple future that is Send

Change summary

zed-rpc/src/peer.rs | 115 ++++++++++++++++------------------------------
zed/src/rpc.rs      |  28 +---------
2 files changed, 45 insertions(+), 98 deletions(-)

Detailed changes

zed-rpc/src/peer.rs 🔗

@@ -4,7 +4,6 @@ use async_lock::{Mutex, RwLock};
 use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
 use futures::{
     future::{BoxFuture, LocalBoxFuture},
-    stream::{SplitSink, SplitStream},
     FutureExt, StreamExt,
 };
 use postage::{
@@ -87,14 +86,6 @@ struct Connection {
     response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
 }
 
-pub struct IOHandler<W, R> {
-    connection_id: ConnectionId,
-    incoming_tx: mpsc::Sender<proto::Envelope>,
-    outgoing_rx: mpsc::Receiver<proto::Envelope>,
-    writer: MessageStream<W>,
-    reader: MessageStream<R>,
-}
-
 impl Peer {
     pub fn new() -> Arc<Self> {
         Arc::new(Self {
@@ -109,7 +100,7 @@ impl Peer {
         router: Arc<RouterInternal<H>>,
     ) -> (
         ConnectionId,
-        IOHandler<SplitSink<Conn, WebSocketMessage>, SplitStream<Conn>>,
+        impl Future<Output = anyhow::Result<()>> + Send,
         impl Future<Output = anyhow::Result<()>>,
     )
     where
@@ -117,6 +108,7 @@ impl Peer {
         Fut: Future<Output = ()>,
         Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
             + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
+            + Send
             + Unpin,
     {
         let (tx, rx) = conn.split();
@@ -124,19 +116,44 @@ impl Peer {
             self.next_connection_id
                 .fetch_add(1, atomic::Ordering::SeqCst),
         );
-        let (incoming_tx, mut incoming_rx) = mpsc::channel(64);
-        let (outgoing_tx, outgoing_rx) = mpsc::channel(64);
+        let (mut incoming_tx, mut incoming_rx) = mpsc::channel(64);
+        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
         let connection = Connection {
             outgoing_tx,
             next_message_id: Default::default(),
             response_channels: Default::default(),
         };
-        let handle_io = IOHandler {
-            connection_id,
-            outgoing_rx,
-            incoming_tx,
-            writer: MessageStream::new(tx),
-            reader: MessageStream::new(rx),
+        let mut writer = MessageStream::new(tx);
+        let mut reader = MessageStream::new(rx);
+
+        let handle_io = async move {
+            loop {
+                let read_message = reader.read_message().fuse();
+                futures::pin_mut!(read_message);
+                loop {
+                    futures::select_biased! {
+                        incoming = read_message => match incoming {
+                            Ok(incoming) => {
+                                if incoming_tx.send(incoming).await.is_err() {
+                                    return Ok(());
+                                }
+                                break;
+                            }
+                            Err(error) => {
+                                Err(error).context("received invalid RPC message")?;
+                            }
+                        },
+                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
+                            Some(outgoing) => {
+                                if let Err(result) = writer.write_message(&outgoing).await {
+                                    Err(result).context("failed to write RPC message")?;
+                                }
+                            }
+                            None => return Ok(()),
+                        }
+                    }
+                }
+            }
         };
 
         let response_channels = connection.response_channels.clone();
@@ -402,56 +419,6 @@ impl ForegroundRouter {
     }
 }
 
-impl<W, R> IOHandler<W, R>
-where
-    W: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
-    R: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
-{
-    pub async fn run(mut self) -> Result<()> {
-        loop {
-            let read_message = self.reader.read_message().fuse();
-            futures::pin_mut!(read_message);
-            loop {
-                futures::select_biased! {
-                    incoming = read_message => match incoming {
-                        Ok(incoming) => {
-                            if self.incoming_tx.send(incoming).await.is_err() {
-                                return Ok(());
-                            }
-                            break;
-                        }
-                        Err(error) => {
-                            Err(error).context("received invalid RPC message")?;
-                        }
-                    },
-                    outgoing = self.outgoing_rx.recv().fuse() => match outgoing {
-                        Some(outgoing) => {
-                            if let Err(result) = self.writer.write_message(&outgoing).await {
-                                Err(result).context("failed to write RPC message")?;
-                            }
-                        }
-                        None => return Ok(()),
-                    }
-                }
-            }
-        }
-    }
-
-    pub async fn receive<M: EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
-        let envelope = self.reader.read_message().await?;
-        let original_sender_id = envelope.original_sender_id;
-        let message_id = envelope.id;
-        let payload =
-            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
-        Ok(TypedEnvelope {
-            sender_id: self.connection_id,
-            original_sender_id: original_sender_id.map(PeerId),
-            message_id,
-            payload,
-        })
-    }
-}
-
 impl<T> Clone for Receipt<T> {
     fn clone(&self) -> Self {
         Self {
@@ -583,10 +550,10 @@ mod tests {
                 .add_connection(server_to_client_2_conn, router.clone())
                 .await;
 
-            smol::spawn(io_task1.run()).detach();
-            smol::spawn(io_task2.run()).detach();
-            smol::spawn(io_task3.run()).detach();
-            smol::spawn(io_task4.run()).detach();
+            smol::spawn(io_task1).detach();
+            smol::spawn(io_task2).detach();
+            smol::spawn(io_task3).detach();
+            smol::spawn(io_task4).detach();
             smol::spawn(msg_task1).detach();
             smol::spawn(msg_task2).detach();
             smol::spawn(msg_task3).detach();
@@ -683,7 +650,7 @@ mod tests {
 
             let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
             smol::spawn(async move {
-                io_handler.run().await.ok();
+                io_handler.await.ok();
                 io_ended_tx.send(()).await.unwrap();
             })
             .detach();
@@ -717,7 +684,7 @@ mod tests {
             let router = Arc::new(Router::new());
             let (connection_id, io_handler, message_handler) =
                 client.add_connection(client_conn, router).await;
-            smol::spawn(io_handler.run()).detach();
+            smol::spawn(io_handler).detach();
             smol::spawn(message_handler).detach();
 
             let err = client

zed/src/rpc.rs 🔗

@@ -91,7 +91,7 @@ impl Client {
         }
 
         let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
-        let user_id = user_id.parse()?;
+        let user_id: i32 = user_id.parse()?;
         let request =
             Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
 
@@ -102,15 +102,13 @@ impl Client {
                 .await
                 .context("websocket handshake")?;
             log::info!("connected to rpc address {}", *ZED_SERVER_URL);
-            self.add_connection(stream, user_id, access_token, router, cx)
-                .await?;
+            self.add_connection(stream, router, cx).await?;
         } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
             let stream = smol::net::TcpStream::connect(host).await?;
             let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
             let (stream, _) = async_tungstenite::client_async(request, stream).await?;
             log::info!("connected to rpc address {}", *ZED_SERVER_URL);
-            self.add_connection(stream, user_id, access_token, router, cx)
-                .await?;
+            self.add_connection(stream, router, cx).await?;
         } else {
             return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
         };
@@ -121,8 +119,6 @@ impl Client {
     pub async fn add_connection<Conn>(
         &self,
         conn: Conn,
-        user_id: i32,
-        access_token: String,
         router: Arc<ForegroundRouter>,
         cx: AsyncAppContext,
     ) -> surf::Result<()>
@@ -138,27 +134,11 @@ impl Client {
         cx.foreground().spawn(handle_messages).detach();
         cx.background()
             .spawn(async move {
-                if let Err(error) = handle_io.run().await {
+                if let Err(error) = handle_io.await {
                     log::error!("connection error: {:?}", error);
                 }
             })
             .detach();
-
-        let auth_response = self
-            .peer
-            .request(
-                connection_id,
-                proto::Auth {
-                    user_id,
-                    access_token,
-                },
-            )
-            .await
-            .context("rpc auth request failed")?;
-        if !auth_response.credentials_valid {
-            Err(anyhow!("failed to authenticate with RPC server"))?;
-        }
-
         self.state.write().await.connection_id = Some(connection_id);
         Ok(())
     }