WIP: Clear cached credentials if authentication fails

Nathan Sobo and Max Brunsfeld created

Still need to actually handle an HTTP response from the server indicating there was an invalid token.

Co-Authored-By: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

Cargo.lock        |  9 ++--
server/src/rpc.rs | 20 +++++----
zed/Cargo.toml    |  1 
zed/src/rpc.rs    | 91 ++++++++++++++++++++++++++++++++++++++++--------
zed/src/test.rs   | 49 ++++++++++++++++++--------
zrpc/src/conn.rs  |  4 +-
zrpc/src/lib.rs   |  2 
zrpc/src/peer.rs  | 40 ++++++++++----------
8 files changed, 149 insertions(+), 67 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5108,18 +5108,18 @@ dependencies = [
 
 [[package]]
 name = "thiserror"
-version = "1.0.24"
+version = "1.0.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e0f4a65597094d4483ddaed134f409b2cb7c1beccf25201a9f73c719254fa98e"
+checksum = "602eca064b2d83369e2b2f34b09c70b605402801927c65c11071ac911d299b88"
 dependencies = [
  "thiserror-impl",
 ]
 
 [[package]]
 name = "thiserror-impl"
-version = "1.0.24"
+version = "1.0.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7765189610d8241a44529806d6fd1f2e0a08734313a35d5b3a556f92b381f3c0"
+checksum = "bad553cc2c78e8de258400763a647e80e6d1b31ee237275d756f6836d204494c"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -5914,6 +5914,7 @@ dependencies = [
  "smol",
  "surf",
  "tempdir",
+ "thiserror",
  "time 0.3.2",
  "tiny_http",
  "toml 0.5.8",

server/src/rpc.rs 🔗

@@ -27,7 +27,7 @@ use time::OffsetDateTime;
 use zrpc::{
     auth::random_token,
     proto::{self, AnyTypedEnvelope, EnvelopedMessage},
-    Conn, ConnectionId, Peer, TypedEnvelope,
+    Connection, ConnectionId, Peer, TypedEnvelope,
 };
 
 type ReplicaId = u16;
@@ -48,13 +48,13 @@ pub struct Server {
 
 #[derive(Default)]
 struct ServerState {
-    connections: HashMap<ConnectionId, Connection>,
+    connections: HashMap<ConnectionId, ConnectionState>,
     pub worktrees: HashMap<u64, Worktree>,
     channels: HashMap<ChannelId, Channel>,
     next_worktree_id: u64,
 }
 
-struct Connection {
+struct ConnectionState {
     user_id: UserId,
     worktrees: HashSet<u64>,
     channels: HashSet<ChannelId>,
@@ -133,7 +133,7 @@ impl Server {
 
     pub fn handle_connection(
         self: &Arc<Self>,
-        connection: Conn,
+        connection: Connection,
         addr: String,
         user_id: UserId,
     ) -> impl Future<Output = ()> {
@@ -211,7 +211,7 @@ impl Server {
     async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
         self.state.write().await.connections.insert(
             connection_id,
-            Connection {
+            ConnectionState {
                 user_id,
                 worktrees: Default::default(),
                 channels: Default::default(),
@@ -972,7 +972,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
             let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
             task::spawn(async move {
                 if let Some(stream) = upgrade_receiver.await {
-                    server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
+                    server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
                 }
             });
 
@@ -1023,7 +1023,7 @@ mod tests {
         editor::{Editor, Insert},
         fs::{FakeFs, Fs as _},
         language::LanguageRegistry,
-        rpc::{self, Client, Credentials},
+        rpc::{self, Client, Credentials, EstablishConnectionError},
         settings,
         test::FakeHttpClient,
         user::UserStore,
@@ -1941,9 +1941,11 @@ mod tests {
                     let client_name = client_name.clone();
                     cx.spawn(move |cx| async move {
                         if forbid_connections.load(SeqCst) {
-                            Err(anyhow!("server is forbidding connections"))
+                            Err(EstablishConnectionError::other(anyhow!(
+                                "server is forbidding connections"
+                            )))
                         } else {
-                            let (client_conn, server_conn, kill_conn) = Conn::in_memory();
+                            let (client_conn, server_conn, kill_conn) = Connection::in_memory();
                             connection_killers.lock().insert(client_user_id, kill_conn);
                             cx.background()
                                 .spawn(server.handle_connection(

zed/Cargo.toml 🔗

@@ -50,6 +50,7 @@ smallvec = { version = "1.6", features = ["union"] }
 smol = "1.2.5"
 surf = "2.2"
 tempdir = { version = "0.3.7", optional = true }
+thiserror = "1.0.29"
 time = { version = "0.3" }
 tiny_http = "0.8"
 toml = "0.5"

zed/src/rpc.rs 🔗

@@ -15,10 +15,11 @@ use std::{
     time::{Duration, Instant},
 };
 use surf::Url;
+use thiserror::Error;
 pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
 use zrpc::{
     proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
-    Conn, Peer, Receipt,
+    Connection, Peer, Receipt,
 };
 
 lazy_static! {
@@ -32,10 +33,32 @@ pub struct Client {
     authenticate:
         Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
     establish_connection: Option<
-        Box<dyn 'static + Send + Sync + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Conn>>>,
+        Box<
+            dyn 'static
+                + Send
+                + Sync
+                + Fn(
+                    &Credentials,
+                    &AsyncAppContext,
+                ) -> Task<Result<Connection, EstablishConnectionError>>,
+        >,
     >,
 }
 
+#[derive(Error, Debug)]
+pub enum EstablishConnectionError {
+    #[error("invalid access token")]
+    InvalidAccessToken,
+    #[error("{0}")]
+    Other(anyhow::Error),
+}
+
+impl EstablishConnectionError {
+    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
+        Self::Other(error.into())
+    }
+}
+
 #[derive(Copy, Clone, Debug)]
 pub enum Status {
     SignedOut,
@@ -122,7 +145,10 @@ impl Client {
     #[cfg(any(test, feature = "test-support"))]
     pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
     where
-        F: 'static + Send + Sync + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Conn>>,
+        F: 'static
+            + Send
+            + Sync
+            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
     {
         self.establish_connection = Some(Box::new(connect));
         self
@@ -288,13 +314,18 @@ impl Client {
                 Ok(())
             }
             Err(err) => {
+                eprintln!("error in authenticate and connect {}", err);
+                if matches!(err, EstablishConnectionError::InvalidAccessToken) {
+                    eprintln!("nuking credentials");
+                    self.state.write().credentials.take();
+                }
                 self.set_status(Status::ConnectionError, cx);
-                Err(err)
+                Err(err)?
             }
         }
     }
 
-    async fn set_connection(self: &Arc<Self>, conn: Conn, cx: &AsyncAppContext) {
+    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
         let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
         cx.foreground()
             .spawn({
@@ -359,7 +390,7 @@ impl Client {
         self: &Arc<Self>,
         credentials: &Credentials,
         cx: &AsyncAppContext,
-    ) -> Task<Result<Conn>> {
+    ) -> Task<Result<Connection, EstablishConnectionError>> {
         if let Some(callback) = self.establish_connection.as_ref() {
             callback(credentials, cx)
         } else {
@@ -371,28 +402,43 @@ impl Client {
         self: &Arc<Self>,
         credentials: &Credentials,
         cx: &AsyncAppContext,
-    ) -> Task<Result<Conn>> {
+    ) -> Task<Result<Connection, EstablishConnectionError>> {
         let request = Request::builder().header(
             "Authorization",
             format!("{} {}", credentials.user_id, credentials.access_token),
         );
         cx.background().spawn(async move {
             if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
-                let stream = smol::net::TcpStream::connect(host).await?;
-                let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
+                let stream = smol::net::TcpStream::connect(host)
+                    .await
+                    .map_err(EstablishConnectionError::other)?;
+                let request = request
+                    .uri(format!("wss://{}/rpc", host))
+                    .body(())
+                    .map_err(EstablishConnectionError::other)?;
                 let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
                     .await
-                    .context("websocket handshake")?;
-                Ok(Conn::new(stream))
+                    .context("websocket handshake")
+                    .map_err(EstablishConnectionError::other)?;
+                Ok(Connection::new(stream))
             } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
-                let stream = smol::net::TcpStream::connect(host).await?;
-                let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
+                let stream = smol::net::TcpStream::connect(host)
+                    .await
+                    .map_err(EstablishConnectionError::other)?;
+                let request = request
+                    .uri(format!("ws://{}/rpc", host))
+                    .body(())
+                    .map_err(EstablishConnectionError::other)?;
                 let (stream, _) = async_tungstenite::client_async(request, stream)
                     .await
-                    .context("websocket handshake")?;
-                Ok(Conn::new(stream))
+                    .context("websocket handshake")
+                    .map_err(EstablishConnectionError::other)?;
+                Ok(Connection::new(stream))
             } else {
-                Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))
+                Err(EstablishConnectionError::other(anyhow!(
+                    "invalid server url: {}",
+                    *ZED_SERVER_URL
+                )))
             }
         })
     }
@@ -591,6 +637,19 @@ mod tests {
         cx.foreground().advance_clock(Duration::from_secs(10));
         while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
         assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
+
+        server.forbid_connections();
+        server.disconnect().await;
+        while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
+
+        // Clear cached credentials after authentication fails
+        server.roll_access_token();
+        server.allow_connections();
+        cx.foreground().advance_clock(Duration::from_secs(10));
+        assert_eq!(server.auth_count(), 1);
+        cx.foreground().advance_clock(Duration::from_secs(10));
+        while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
+        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
     }
 
     #[test]

zed/src/test.rs 🔗

@@ -4,7 +4,7 @@ use crate::{
     fs::RealFs,
     http::{HttpClient, Request, Response, ServerResponse},
     language::LanguageRegistry,
-    rpc::{self, Client, Credentials},
+    rpc::{self, Client, Credentials, EstablishConnectionError},
     settings::{self, ThemeRegistry},
     time::ReplicaId,
     user::UserStore,
@@ -26,7 +26,7 @@ use std::{
     },
 };
 use tempdir::TempDir;
-use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
+use zrpc::{proto, Connection, ConnectionId, Peer, Receipt, TypedEnvelope};
 
 #[cfg(test)]
 #[ctor::ctor]
@@ -210,6 +210,8 @@ pub struct FakeServer {
     connection_id: Mutex<Option<ConnectionId>>,
     forbid_connections: AtomicBool,
     auth_count: AtomicUsize,
+    access_token: AtomicUsize,
+    user_id: u64,
 }
 
 impl FakeServer {
@@ -224,6 +226,8 @@ impl FakeServer {
             connection_id: Default::default(),
             forbid_connections: Default::default(),
             auth_count: Default::default(),
+            access_token: Default::default(),
+            user_id: client_user_id,
         });
 
         Arc::get_mut(client)
@@ -232,8 +236,8 @@ impl FakeServer {
                 let server = server.clone();
                 move |cx| {
                     server.auth_count.fetch_add(1, SeqCst);
+                    let access_token = server.access_token.load(SeqCst).to_string();
                     cx.spawn(move |_| async move {
-                        let access_token = "the-token".to_string();
                         Ok(Credentials {
                             user_id: client_user_id,
                             access_token,
@@ -244,11 +248,10 @@ impl FakeServer {
             .override_establish_connection({
                 let server = server.clone();
                 move |credentials, cx| {
-                    assert_eq!(credentials.user_id, client_user_id);
-                    assert_eq!(credentials.access_token, "the-token");
+                    let credentials = credentials.clone();
                     cx.spawn({
                         let server = server.clone();
-                        move |cx| async move { server.connect(&cx).await }
+                        move |cx| async move { server.establish_connection(&credentials, &cx).await }
                     })
                 }
             });
@@ -266,23 +269,39 @@ impl FakeServer {
         self.incoming.lock().take();
     }
 
-    async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
+    async fn establish_connection(
+        &self,
+        credentials: &Credentials,
+        cx: &AsyncAppContext,
+    ) -> Result<Connection, EstablishConnectionError> {
+        assert_eq!(credentials.user_id, self.user_id);
+
         if self.forbid_connections.load(SeqCst) {
-            Err(anyhow!("server is forbidding connections"))
-        } else {
-            let (client_conn, server_conn, _) = Conn::in_memory();
-            let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
-            cx.background().spawn(io).detach();
-            *self.incoming.lock() = Some(incoming);
-            *self.connection_id.lock() = Some(connection_id);
-            Ok(client_conn)
+            Err(EstablishConnectionError::Other(anyhow!(
+                "server is forbidding connections"
+            )))?
+        }
+
+        if credentials.access_token != self.access_token.load(SeqCst).to_string() {
+            Err(EstablishConnectionError::InvalidAccessToken)?
         }
+
+        let (client_conn, server_conn, _) = Connection::in_memory();
+        let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
+        cx.background().spawn(io).detach();
+        *self.incoming.lock() = Some(incoming);
+        *self.connection_id.lock() = Some(connection_id);
+        Ok(client_conn)
     }
 
     pub fn auth_count(&self) -> usize {
         self.auth_count.load(SeqCst)
     }
 
+    pub fn roll_access_token(&self) {
+        self.access_token.fetch_add(1, SeqCst);
+    }
+
     pub fn forbid_connections(&self) {
         self.forbid_connections.store(true, SeqCst);
     }

zrpc/src/conn.rs 🔗

@@ -2,7 +2,7 @@ use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSock
 use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _};
 use std::{io, task::Poll};
 
-pub struct Conn {
+pub struct Connection {
     pub(crate) tx:
         Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
     pub(crate) rx: Box<
@@ -13,7 +13,7 @@ pub struct Conn {
     >,
 }
 
-impl Conn {
+impl Connection {
     pub fn new<S>(stream: S) -> Self
     where
         S: 'static

zrpc/src/lib.rs 🔗

@@ -2,5 +2,5 @@ pub mod auth;
 mod conn;
 mod peer;
 pub mod proto;
-pub use conn::Conn;
+pub use conn::Connection;
 pub use peer::*;

zrpc/src/peer.rs 🔗

@@ -1,5 +1,5 @@
 use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
-use super::Conn;
+use super::Connection;
 use anyhow::{anyhow, Context, Result};
 use async_lock::{Mutex, RwLock};
 use futures::FutureExt as _;
@@ -79,12 +79,12 @@ impl<T: RequestMessage> TypedEnvelope<T> {
 }
 
 pub struct Peer {
-    connections: RwLock<HashMap<ConnectionId, Connection>>,
+    connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
     next_connection_id: AtomicU32,
 }
 
 #[derive(Clone)]
-struct Connection {
+struct ConnectionState {
     outgoing_tx: mpsc::Sender<proto::Envelope>,
     next_message_id: Arc<AtomicU32>,
     response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
@@ -100,7 +100,7 @@ impl Peer {
 
     pub async fn add_connection(
         self: &Arc<Self>,
-        conn: Conn,
+        connection: Connection,
     ) -> (
         ConnectionId,
         impl Future<Output = anyhow::Result<()>> + Send,
@@ -112,16 +112,16 @@ impl Peer {
         );
         let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
         let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
-        let connection = Connection {
+        let connection_state = ConnectionState {
             outgoing_tx,
             next_message_id: Default::default(),
             response_channels: Default::default(),
         };
-        let mut writer = MessageStream::new(conn.tx);
-        let mut reader = MessageStream::new(conn.rx);
+        let mut writer = MessageStream::new(connection.tx);
+        let mut reader = MessageStream::new(connection.rx);
 
         let this = self.clone();
-        let response_channels = connection.response_channels.clone();
+        let response_channels = connection_state.response_channels.clone();
         let handle_io = async move {
             loop {
                 let read_message = reader.read_message().fuse();
@@ -179,7 +179,7 @@ impl Peer {
         self.connections
             .write()
             .await
-            .insert(connection_id, connection);
+            .insert(connection_id, connection_state);
 
         (connection_id, handle_io, incoming_rx)
     }
@@ -218,7 +218,7 @@ impl Peer {
         let this = self.clone();
         let (tx, mut rx) = mpsc::channel(1);
         async move {
-            let mut connection = this.connection(receiver_id).await?;
+            let mut connection = this.connection_state(receiver_id).await?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
@@ -252,7 +252,7 @@ impl Peer {
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         async move {
-            let mut connection = this.connection(receiver_id).await?;
+            let mut connection = this.connection_state(receiver_id).await?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
@@ -272,7 +272,7 @@ impl Peer {
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         async move {
-            let mut connection = this.connection(receiver_id).await?;
+            let mut connection = this.connection_state(receiver_id).await?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
@@ -291,7 +291,7 @@ impl Peer {
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         async move {
-            let mut connection = this.connection(receipt.sender_id).await?;
+            let mut connection = this.connection_state(receipt.sender_id).await?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
@@ -310,7 +310,7 @@ impl Peer {
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         async move {
-            let mut connection = this.connection(receipt.sender_id).await?;
+            let mut connection = this.connection_state(receipt.sender_id).await?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
@@ -322,10 +322,10 @@ impl Peer {
         }
     }
 
-    fn connection(
+    fn connection_state(
         self: &Arc<Self>,
         connection_id: ConnectionId,
-    ) -> impl Future<Output = Result<Connection>> {
+    ) -> impl Future<Output = Result<ConnectionState>> {
         let this = self.clone();
         async move {
             let connections = this.connections.read().await;
@@ -352,12 +352,12 @@ mod tests {
             let client1 = Peer::new();
             let client2 = Peer::new();
 
-            let (client1_to_server_conn, server_to_client_1_conn, _) = Conn::in_memory();
+            let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
             let (client1_conn_id, io_task1, _) =
                 client1.add_connection(client1_to_server_conn).await;
             let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
 
-            let (client2_to_server_conn, server_to_client_2_conn, _) = Conn::in_memory();
+            let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
             let (client2_conn_id, io_task3, _) =
                 client2.add_connection(client2_to_server_conn).await;
             let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
@@ -486,7 +486,7 @@ mod tests {
     #[test]
     fn test_disconnect() {
         smol::block_on(async move {
-            let (client_conn, mut server_conn, _) = Conn::in_memory();
+            let (client_conn, mut server_conn, _) = Connection::in_memory();
 
             let client = Peer::new();
             let (connection_id, io_handler, mut incoming) =
@@ -520,7 +520,7 @@ mod tests {
     #[test]
     fn test_io_error() {
         smol::block_on(async move {
-            let (client_conn, server_conn, _) = Conn::in_memory();
+            let (client_conn, server_conn, _) = Connection::in_memory();
             drop(server_conn);
 
             let client = Peer::new();