diff --git a/server/src/auth.rs b/server/src/auth.rs index 0578f2c9fafffe496c5fd8c68a502cc4efe85743..4a06c642eb2554128b06e477c1e934a1fde83b36 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -18,7 +18,7 @@ use scrypt::{ use serde::{Deserialize, Serialize}; use std::{borrow::Cow, convert::TryFrom, sync::Arc}; use surf::{StatusCode, Url}; -use tide::{log, Server}; +use tide::{log, Error, Server}; use zrpc::auth as zed_auth; static CURRENT_GITHUB_USER: &'static str = "current_github_user"; @@ -33,51 +33,48 @@ pub struct User { pub is_admin: bool, } -pub struct VerifyToken; - -#[async_trait] -impl tide::Middleware> for VerifyToken { - async fn handle( - &self, - mut request: Request, - next: tide::Next<'_, Arc>, - ) -> tide::Result { - let mut auth_header = request - .header("Authorization") - .ok_or_else(|| anyhow!("no authorization header"))? - .last() - .as_str() - .split_whitespace(); - - let user_id = UserId( - auth_header - .next() - .ok_or_else(|| anyhow!("missing user id in authorization header"))? - .parse()?, - ); - let access_token = auth_header - .next() - .ok_or_else(|| anyhow!("missing access token in authorization header"))?; - - let state = request.state().clone(); - - let mut credentials_valid = false; - for password_hash in state.db.get_access_token_hashes(user_id).await? { - if verify_access_token(&access_token, &password_hash)? { - credentials_valid = true; - break; - } +pub async fn process_auth_header(request: &Request) -> tide::Result { + let mut auth_header = request + .header("Authorization") + .ok_or_else(|| { + Error::new( + StatusCode::BadRequest, + anyhow!("missing authorization header"), + ) + })? + .last() + .as_str() + .split_whitespace(); + let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| { + Error::new( + StatusCode::BadRequest, + anyhow!("missing user id in authorization header"), + ) + })?); + let access_token = auth_header.next().ok_or_else(|| { + Error::new( + StatusCode::BadRequest, + anyhow!("missing access token in authorization header"), + ) + })?; + + let state = request.state().clone(); + let mut credentials_valid = false; + for password_hash in state.db.get_access_token_hashes(user_id).await? { + if verify_access_token(&access_token, &password_hash)? { + credentials_valid = true; + break; } + } - if credentials_valid { - request.set_ext(user_id); - Ok(next.run(request).await) - } else { - let mut response = tide::Response::new(StatusCode::Unauthorized); - response.set_body("invalid credentials"); - Ok(response) - } + if !credentials_valid { + Err(Error::new( + StatusCode::Unauthorized, + anyhow!("invalid credentials"), + ))?; } + + Ok(user_id) } #[async_trait] diff --git a/server/src/rpc.rs b/server/src/rpc.rs index 1d8de2d9a3a0c0ba1c8f8b74f90cf9871c754a62..a9ffdad8997d8e2338223b0afdfdd05e3981fd1b 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1,7 +1,7 @@ mod store; use super::{ - auth, + auth::process_auth_header, db::{ChannelId, MessageId, UserId}, AppState, }; @@ -885,8 +885,7 @@ where pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { let server = Server::new(app.state().clone(), rpc.clone(), None); - app.at("/rpc").with(auth::VerifyToken).get(move |request: Request>| { - let user_id = request.ext::().copied(); + app.at("/rpc").get(move |request: Request>| { let server = server.clone(); async move { const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; @@ -907,6 +906,8 @@ pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { None => return Err(anyhow!("expected sec-websocket-key"))?, }; + let user_id = process_auth_header(&request).await?; + let mut response = Response::new(StatusCode::SwitchingProtocols); response.insert_header(UPGRADE, "websocket"); response.insert_header(CONNECTION, "Upgrade"); @@ -917,10 +918,17 @@ pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { let http_res: &mut tide::http::Response = response.as_mut(); let upgrade_receiver = http_res.recv_upgrade().await; let addr = request.remote().unwrap_or("unknown").to_string(); - let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?; task::spawn(async move { if let Some(stream) = upgrade_receiver.await { - server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await; + server + .handle_connection( + Connection::new( + WebSocketStream::from_raw_socket(stream, Role::Server, None).await, + ), + addr, + user_id, + ) + .await; } });