@@ -18,7 +18,7 @@ use scrypt::{
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, convert::TryFrom, sync::Arc};
use surf::{StatusCode, Url};
-use tide::{log, Server};
+use tide::{log, Error, Server};
use zrpc::auth as zed_auth;
static CURRENT_GITHUB_USER: &'static str = "current_github_user";
@@ -33,51 +33,48 @@ pub struct User {
pub is_admin: bool,
}
-pub struct VerifyToken;
-
-#[async_trait]
-impl tide::Middleware<Arc<AppState>> for VerifyToken {
- async fn handle(
- &self,
- mut request: Request,
- next: tide::Next<'_, Arc<AppState>>,
- ) -> tide::Result {
- let mut auth_header = request
- .header("Authorization")
- .ok_or_else(|| anyhow!("no authorization header"))?
- .last()
- .as_str()
- .split_whitespace();
-
- let user_id = UserId(
- auth_header
- .next()
- .ok_or_else(|| anyhow!("missing user id in authorization header"))?
- .parse()?,
- );
- let access_token = auth_header
- .next()
- .ok_or_else(|| anyhow!("missing access token in authorization header"))?;
-
- let state = request.state().clone();
-
- let mut credentials_valid = false;
- for password_hash in state.db.get_access_token_hashes(user_id).await? {
- if verify_access_token(&access_token, &password_hash)? {
- credentials_valid = true;
- break;
- }
+pub async fn process_auth_header(request: &Request) -> tide::Result<UserId> {
+ let mut auth_header = request
+ .header("Authorization")
+ .ok_or_else(|| {
+ Error::new(
+ StatusCode::BadRequest,
+ anyhow!("missing authorization header"),
+ )
+ })?
+ .last()
+ .as_str()
+ .split_whitespace();
+ let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
+ Error::new(
+ StatusCode::BadRequest,
+ anyhow!("missing user id in authorization header"),
+ )
+ })?);
+ let access_token = auth_header.next().ok_or_else(|| {
+ Error::new(
+ StatusCode::BadRequest,
+ anyhow!("missing access token in authorization header"),
+ )
+ })?;
+
+ let state = request.state().clone();
+ let mut credentials_valid = false;
+ for password_hash in state.db.get_access_token_hashes(user_id).await? {
+ if verify_access_token(&access_token, &password_hash)? {
+ credentials_valid = true;
+ break;
}
+ }
- if credentials_valid {
- request.set_ext(user_id);
- Ok(next.run(request).await)
- } else {
- let mut response = tide::Response::new(StatusCode::Unauthorized);
- response.set_body("invalid credentials");
- Ok(response)
- }
+ if !credentials_valid {
+ Err(Error::new(
+ StatusCode::Unauthorized,
+ anyhow!("invalid credentials"),
+ ))?;
}
+
+ Ok(user_id)
}
#[async_trait]
@@ -1,7 +1,7 @@
mod store;
use super::{
- auth,
+ auth::process_auth_header,
db::{ChannelId, MessageId, UserId},
AppState,
};
@@ -885,8 +885,7 @@ where
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let server = Server::new(app.state().clone(), rpc.clone(), None);
- app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
- let user_id = request.ext::<UserId>().copied();
+ app.at("/rpc").get(move |request: Request<Arc<AppState>>| {
let server = server.clone();
async move {
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
@@ -907,6 +906,8 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
None => return Err(anyhow!("expected sec-websocket-key"))?,
};
+ let user_id = process_auth_header(&request).await?;
+
let mut response = Response::new(StatusCode::SwitchingProtocols);
response.insert_header(UPGRADE, "websocket");
response.insert_header(CONNECTION, "Upgrade");
@@ -917,10 +918,17 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let http_res: &mut tide::http::Response = response.as_mut();
let upgrade_receiver = http_res.recv_upgrade().await;
let addr = request.remote().unwrap_or("unknown").to_string();
- let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
task::spawn(async move {
if let Some(stream) = upgrade_receiver.await {
- server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
+ server
+ .handle_connection(
+ Connection::new(
+ WebSocketStream::from_raw_socket(stream, Role::Server, None).await,
+ ),
+ addr,
+ user_id,
+ )
+ .await;
}
});