Avoid verifying access tokens for out-of-date clients

Max Brunsfeld created

Replace the 'VerifyToken' middleware with a 'process_auth_header' function
that we call in the '/rpc' handler after checking that the client's protocol
version matches.

Change summary

server/src/auth.rs | 83 +++++++++++++++++++++++------------------------
server/src/rpc.rs  | 18 +++++++--
2 files changed, 53 insertions(+), 48 deletions(-)

Detailed changes

server/src/auth.rs 🔗

@@ -18,7 +18,7 @@ use scrypt::{
 use serde::{Deserialize, Serialize};
 use std::{borrow::Cow, convert::TryFrom, sync::Arc};
 use surf::{StatusCode, Url};
-use tide::{log, Server};
+use tide::{log, Error, Server};
 use zrpc::auth as zed_auth;
 
 static CURRENT_GITHUB_USER: &'static str = "current_github_user";
@@ -33,51 +33,48 @@ pub struct User {
     pub is_admin: bool,
 }
 
-pub struct VerifyToken;
-
-#[async_trait]
-impl tide::Middleware<Arc<AppState>> for VerifyToken {
-    async fn handle(
-        &self,
-        mut request: Request,
-        next: tide::Next<'_, Arc<AppState>>,
-    ) -> tide::Result {
-        let mut auth_header = request
-            .header("Authorization")
-            .ok_or_else(|| anyhow!("no authorization header"))?
-            .last()
-            .as_str()
-            .split_whitespace();
-
-        let user_id = UserId(
-            auth_header
-                .next()
-                .ok_or_else(|| anyhow!("missing user id in authorization header"))?
-                .parse()?,
-        );
-        let access_token = auth_header
-            .next()
-            .ok_or_else(|| anyhow!("missing access token in authorization header"))?;
-
-        let state = request.state().clone();
-
-        let mut credentials_valid = false;
-        for password_hash in state.db.get_access_token_hashes(user_id).await? {
-            if verify_access_token(&access_token, &password_hash)? {
-                credentials_valid = true;
-                break;
-            }
+pub async fn process_auth_header(request: &Request) -> tide::Result<UserId> {
+    let mut auth_header = request
+        .header("Authorization")
+        .ok_or_else(|| {
+            Error::new(
+                StatusCode::BadRequest,
+                anyhow!("missing authorization header"),
+            )
+        })?
+        .last()
+        .as_str()
+        .split_whitespace();
+    let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
+        Error::new(
+            StatusCode::BadRequest,
+            anyhow!("missing user id in authorization header"),
+        )
+    })?);
+    let access_token = auth_header.next().ok_or_else(|| {
+        Error::new(
+            StatusCode::BadRequest,
+            anyhow!("missing access token in authorization header"),
+        )
+    })?;
+
+    let state = request.state().clone();
+    let mut credentials_valid = false;
+    for password_hash in state.db.get_access_token_hashes(user_id).await? {
+        if verify_access_token(&access_token, &password_hash)? {
+            credentials_valid = true;
+            break;
         }
+    }
 
-        if credentials_valid {
-            request.set_ext(user_id);
-            Ok(next.run(request).await)
-        } else {
-            let mut response = tide::Response::new(StatusCode::Unauthorized);
-            response.set_body("invalid credentials");
-            Ok(response)
-        }
+    if !credentials_valid {
+        Err(Error::new(
+            StatusCode::Unauthorized,
+            anyhow!("invalid credentials"),
+        ))?;
     }
+
+    Ok(user_id)
 }
 
 #[async_trait]

server/src/rpc.rs 🔗

@@ -1,7 +1,7 @@
 mod store;
 
 use super::{
-    auth,
+    auth::process_auth_header,
     db::{ChannelId, MessageId, UserId},
     AppState,
 };
@@ -885,8 +885,7 @@ where
 
 pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
     let server = Server::new(app.state().clone(), rpc.clone(), None);
-    app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
-        let user_id = request.ext::<UserId>().copied();
+    app.at("/rpc").get(move |request: Request<Arc<AppState>>| {
         let server = server.clone();
         async move {
             const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
@@ -907,6 +906,8 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
                 None => return Err(anyhow!("expected sec-websocket-key"))?,
             };
 
+            let user_id = process_auth_header(&request).await?;
+
             let mut response = Response::new(StatusCode::SwitchingProtocols);
             response.insert_header(UPGRADE, "websocket");
             response.insert_header(CONNECTION, "Upgrade");
@@ -917,10 +918,17 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
             let http_res: &mut tide::http::Response = response.as_mut();
             let upgrade_receiver = http_res.recv_upgrade().await;
             let addr = request.remote().unwrap_or("unknown").to_string();
-            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
             task::spawn(async move {
                 if let Some(stream) = upgrade_receiver.await {
-                    server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
+                    server
+                        .handle_connection(
+                            Connection::new(
+                                WebSocketStream::from_raw_socket(stream, Role::Server, None).await,
+                            ),
+                            addr,
+                            user_id,
+                        )
+                        .await;
                 }
             });