Clear cached credentials when establishing a websocket connection with an invalid token

Nathan Sobo created

Change summary

gpui/src/platform.rs              |  1 
gpui/src/platform/mac/platform.rs | 20 ++++++++++
gpui/src/platform/test.rs         |  4 ++
server/src/auth.rs                |  6 ++-
zed/src/rpc.rs                    | 65 ++++++++++++++++----------------
zed/src/test.rs                   |  2 
zed/src/workspace.rs              |  2 
7 files changed, 63 insertions(+), 37 deletions(-)

Detailed changes

gpui/src/platform.rs 🔗

@@ -48,6 +48,7 @@ pub trait Platform: Send + Sync {
 
     fn write_credentials(&self, url: &str, username: &str, password: &[u8]) -> Result<()>;
     fn read_credentials(&self, url: &str) -> Result<Option<(String, Vec<u8>)>>;
+    fn delete_credentials(&self, url: &str) -> Result<()>;
 
     fn set_cursor_style(&self, style: CursorStyle);
 

gpui/src/platform/mac/platform.rs 🔗

@@ -551,6 +551,25 @@ impl platform::Platform for MacPlatform {
         }
     }
 
+    fn delete_credentials(&self, url: &str) -> Result<()> {
+        let url = CFString::from(url);
+
+        unsafe {
+            use security::*;
+
+            let mut query_attrs = CFMutableDictionary::with_capacity(2);
+            query_attrs.set(kSecClass as *const _, kSecClassInternetPassword as *const _);
+            query_attrs.set(kSecAttrServer as *const _, url.as_CFTypeRef());
+
+            let status = SecItemDelete(query_attrs.as_concrete_TypeRef());
+
+            if status != errSecSuccess {
+                return Err(anyhow!("delete password failed: {}", status));
+            }
+        }
+        Ok(())
+    }
+
     fn set_cursor_style(&self, style: CursorStyle) {
         unsafe {
             let cursor: id = match style {
@@ -676,6 +695,7 @@ mod security {
 
         pub fn SecItemAdd(attributes: CFDictionaryRef, result: *mut CFTypeRef) -> OSStatus;
         pub fn SecItemUpdate(query: CFDictionaryRef, attributes: CFDictionaryRef) -> OSStatus;
+        pub fn SecItemDelete(query: CFDictionaryRef) -> OSStatus;
         pub fn SecItemCopyMatching(query: CFDictionaryRef, result: *mut CFTypeRef) -> OSStatus;
     }
 

gpui/src/platform/test.rs 🔗

@@ -137,6 +137,10 @@ impl super::Platform for Platform {
         Ok(None)
     }
 
+    fn delete_credentials(&self, _: &str) -> Result<()> {
+        Ok(())
+    }
+
     fn set_cursor_style(&self, style: CursorStyle) {
         *self.cursor.lock() = style;
     }

server/src/auth.rs 🔗

@@ -17,7 +17,7 @@ use scrypt::{
 };
 use serde::{Deserialize, Serialize};
 use std::{borrow::Cow, convert::TryFrom, sync::Arc};
-use surf::Url;
+use surf::{StatusCode, Url};
 use tide::Server;
 use zrpc::auth as zed_auth;
 
@@ -73,7 +73,9 @@ impl tide::Middleware<Arc<AppState>> for VerifyToken {
             request.set_ext(user_id);
             Ok(next.run(request).await)
         } else {
-            Err(anyhow!("invalid credentials").into())
+            let mut response = tide::Response::new(StatusCode::Unauthorized);
+            response.set_body("invalid credentials");
+            Ok(response)
         }
     }
 }

zed/src/rpc.rs 🔗

@@ -1,6 +1,9 @@
 use crate::util::ResultExt;
 use anyhow::{anyhow, Context, Result};
-use async_tungstenite::tungstenite::http::Request;
+use async_tungstenite::tungstenite::{
+    error::Error as WebsocketError,
+    http::{Request, StatusCode},
+};
 use gpui::{AsyncAppContext, Entity, ModelContext, Task};
 use lazy_static::lazy_static;
 use parking_lot::RwLock;
@@ -47,10 +50,25 @@ pub struct Client {
 
 #[derive(Error, Debug)]
 pub enum EstablishConnectionError {
-    #[error("invalid access token")]
-    InvalidAccessToken,
+    #[error("unauthorized")]
+    Unauthorized,
+    #[error("{0}")]
+    Other(#[from] anyhow::Error),
     #[error("{0}")]
-    Other(anyhow::Error),
+    Io(#[from] std::io::Error),
+    #[error("{0}")]
+    Http(#[from] async_tungstenite::tungstenite::http::Error),
+}
+
+impl From<WebsocketError> for EstablishConnectionError {
+    fn from(error: WebsocketError) -> Self {
+        if let WebsocketError::Http(response) = &error {
+            if response.status() == StatusCode::UNAUTHORIZED {
+                return EstablishConnectionError::Unauthorized;
+            }
+        }
+        EstablishConnectionError::Other(error.into())
+    }
 }
 
 impl EstablishConnectionError {
@@ -314,10 +332,9 @@ impl Client {
                 Ok(())
             }
             Err(err) => {
-                eprintln!("error in authenticate and connect {}", err);
-                if matches!(err, EstablishConnectionError::InvalidAccessToken) {
-                    eprintln!("nuking credentials");
+                if matches!(err, EstablishConnectionError::Unauthorized) {
                     self.state.write().credentials.take();
+                    cx.platform().delete_credentials(&ZED_SERVER_URL).ok();
                 }
                 self.set_status(Status::ConnectionError, cx);
                 Err(err)?
@@ -409,36 +426,18 @@ impl Client {
         );
         cx.background().spawn(async move {
             if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
-                let stream = smol::net::TcpStream::connect(host)
-                    .await
-                    .map_err(EstablishConnectionError::other)?;
-                let request = request
-                    .uri(format!("wss://{}/rpc", host))
-                    .body(())
-                    .map_err(EstablishConnectionError::other)?;
-                let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
-                    .await
-                    .context("websocket handshake")
-                    .map_err(EstablishConnectionError::other)?;
+                let stream = smol::net::TcpStream::connect(host).await?;
+                let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
+                let (stream, _) =
+                    async_tungstenite::async_tls::client_async_tls(request, stream).await?;
                 Ok(Connection::new(stream))
             } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
-                let stream = smol::net::TcpStream::connect(host)
-                    .await
-                    .map_err(EstablishConnectionError::other)?;
-                let request = request
-                    .uri(format!("ws://{}/rpc", host))
-                    .body(())
-                    .map_err(EstablishConnectionError::other)?;
-                let (stream, _) = async_tungstenite::client_async(request, stream)
-                    .await
-                    .context("websocket handshake")
-                    .map_err(EstablishConnectionError::other)?;
+                let stream = smol::net::TcpStream::connect(host).await?;
+                let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
+                let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                 Ok(Connection::new(stream))
             } else {
-                Err(EstablishConnectionError::other(anyhow!(
-                    "invalid server url: {}",
-                    *ZED_SERVER_URL
-                )))
+                Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?
             }
         })
     }

zed/src/test.rs 🔗

@@ -283,7 +283,7 @@ impl FakeServer {
         }
 
         if credentials.access_token != self.access_token.load(SeqCst).to_string() {
-            Err(EstablishConnectionError::InvalidAccessToken)?
+            Err(EstablishConnectionError::Unauthorized)?
         }
 
         let (client_conn, server_conn, _) = Connection::in_memory();

zed/src/workspace.rs 🔗

@@ -960,7 +960,7 @@ impl Workspace {
 
     fn render_connection_status(&self) -> Option<ElementBox> {
         let theme = &self.settings.borrow().theme;
-        match dbg!(&*self.rpc.status().borrow()) {
+        match &*self.rpc.status().borrow() {
             rpc::Status::ConnectionError
             | rpc::Status::ConnectionLost
             | rpc::Status::Reauthenticating