WIP

Nathan Sobo created

Change summary

crates/collab/src/api.rs  | 79 ++++++++++++++++++++---------------
crates/collab/src/auth.rs | 90 ++++++++++++++++++++++------------------
crates/collab/src/main.rs | 23 ++--------
3 files changed, 99 insertions(+), 93 deletions(-)

Detailed changes

crates/collab/src/api.rs 🔗

@@ -7,42 +7,45 @@ use anyhow::anyhow;
 use axum::{
     body::Body,
     extract::{Path, Query},
-    http::StatusCode,
-    routing::{delete, get, post, put},
-    Json, Router,
+    http::{self, Request, StatusCode},
+    middleware::{self, Next},
+    response::IntoResponse,
+    routing::{get, post, put},
+    Extension, Json, Router,
 };
 use serde::{Deserialize, Serialize};
 use std::sync::Arc;
+use tower::ServiceBuilder;
+
+pub fn routes(state: Arc<AppState>) -> Router<Body> {
+    Router::new()
+        .route("/users", get(get_users).post(create_user))
+        .route("/users/:id", put(update_user).delete(destroy_user))
+        .route("/users/:gh_login", get(get_user))
+        .route("/users/:gh_login/access_tokens", post(create_access_token))
+        .layer(
+            ServiceBuilder::new()
+                .layer(Extension(state))
+                .layer(middleware::from_fn(validate_api_token)),
+        )
+}
 
-pub fn add_routes(router: Router<Body>, app: Arc<AppState>) -> Router<Body> {
-    router
-        .route("/users", {
-            let app = app.clone();
-            get(move || get_users(app))
-        })
-        .route("/users", {
-            let app = app.clone();
-            post(move |params| create_user(params, app))
-        })
-        .route("/users/:id", {
-            let app = app.clone();
-            put(move |user_id, params| update_user(user_id, params, app))
-        })
-        .route("/users/:id", {
-            let app = app.clone();
-            delete(move |user_id| destroy_user(user_id, app))
-        })
-        .route("/users/:github_login", {
-            let app = app.clone();
-            get(move |github_login| get_user(github_login, app))
-        })
-        .route("/users/:github_login/access_tokens", {
-            let app = app.clone();
-            post(move |github_login, params| create_access_token(github_login, params, app))
-        })
+pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
+    let mut auth_header = req
+        .headers()
+        .get(http::header::AUTHORIZATION)
+        .and_then(|header| header.to_str().ok())
+        .ok_or_else(|| {
+            Error::Http(
+                StatusCode::BAD_REQUEST,
+                "missing authorization header".to_string(),
+            )
+        })?;
+
+    Ok::<_, Error>(next.run(req).await)
 }
 
-async fn get_users(app: Arc<AppState>) -> Result<Json<Vec<User>>> {
+async fn get_users(Extension(app): Extension<Arc<AppState>>) -> Result<Json<Vec<User>>> {
     let users = app.db.get_all_users().await?;
     Ok(Json(users))
 }
@@ -55,7 +58,7 @@ struct CreateUserParams {
 
 async fn create_user(
     Json(params): Json<CreateUserParams>,
-    app: Arc<AppState>,
+    Extension(app): Extension<Arc<AppState>>,
 ) -> Result<Json<User>> {
     let user_id = app
         .db
@@ -79,7 +82,7 @@ struct UpdateUserParams {
 async fn update_user(
     Path(user_id): Path<i32>,
     Json(params): Json<UpdateUserParams>,
-    app: Arc<AppState>,
+    Extension(app): Extension<Arc<AppState>>,
 ) -> Result<()> {
     app.db
         .set_user_is_admin(UserId(user_id), params.admin)
@@ -87,12 +90,18 @@ async fn update_user(
     Ok(())
 }
 
-async fn destroy_user(Path(user_id): Path<i32>, app: Arc<AppState>) -> Result<()> {
+async fn destroy_user(
+    Path(user_id): Path<i32>,
+    Extension(app): Extension<Arc<AppState>>,
+) -> Result<()> {
     app.db.destroy_user(UserId(user_id)).await?;
     Ok(())
 }
 
-async fn get_user(Path(login): Path<String>, app: Arc<AppState>) -> Result<Json<User>> {
+async fn get_user(
+    Path(login): Path<String>,
+    Extension(app): Extension<Arc<AppState>>,
+) -> Result<Json<User>> {
     let user = app
         .db
         .get_user_by_github_login(&login)
@@ -116,7 +125,7 @@ struct CreateAccessTokenResponse {
 async fn create_access_token(
     Path(login): Path<String>,
     Query(params): Query<CreateAccessTokenQueryParams>,
-    app: Arc<AppState>,
+    Extension(app): Extension<Arc<AppState>>,
 ) -> Result<Json<CreateAccessTokenResponse>> {
     //     request.require_token().await?;
 

crates/collab/src/auth.rs 🔗

@@ -1,54 +1,64 @@
+use std::sync::Arc;
+
 use super::db::{self, UserId};
+use crate::{AppState, Error};
 use anyhow::{Context, Result};
+use axum::{
+    http::{self, Request, StatusCode},
+    middleware::Next,
+    response::IntoResponse,
+};
 use rand::thread_rng;
 use scrypt::{
     password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
     Scrypt,
 };
 
-// pub async fn process_auth_header(request: &Request) -> tide::Result<UserId> {
-//     let mut auth_header = request
-//         .header("Authorization")
-//         .ok_or_else(|| {
-//             Error::new(
-//                 StatusCode::BadRequest,
-//                 anyhow!("missing authorization header"),
-//             )
-//         })?
-//         .last()
-//         .as_str()
-//         .split_whitespace();
-//     let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
-//         Error::new(
-//             StatusCode::BadRequest,
-//             anyhow!("missing user id in authorization header"),
-//         )
-//     })?);
-//     let access_token = auth_header.next().ok_or_else(|| {
-//         Error::new(
-//             StatusCode::BadRequest,
-//             anyhow!("missing access token in authorization header"),
-//         )
-//     })?;
+pub async fn validate_header<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
+    let mut auth_header = req
+        .headers()
+        .get(http::header::AUTHORIZATION)
+        .and_then(|header| header.to_str().ok())
+        .ok_or_else(|| {
+            Error::Http(
+                StatusCode::BAD_REQUEST,
+                "missing authorization header".to_string(),
+            )
+        })?
+        .split_whitespace();
+
+    let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
+        Error::Http(
+            StatusCode::BAD_REQUEST,
+            "missing user id in authorization header".to_string(),
+        )
+    })?);
 
-//     let state = request.state().clone();
-//     let mut credentials_valid = false;
-//     for password_hash in state.db.get_access_token_hashes(user_id).await? {
-//         if verify_access_token(&access_token, &password_hash)? {
-//             credentials_valid = true;
-//             break;
-//         }
-//     }
+    let access_token = auth_header.next().ok_or_else(|| {
+        Error::Http(
+            StatusCode::BAD_REQUEST,
+            "missing access token in authorization header".to_string(),
+        )
+    })?;
 
-//     if !credentials_valid {
-//         Err(Error::new(
-//             StatusCode::Unauthorized,
-//             anyhow!("invalid credentials"),
-//         ))?;
-//     }
+    let state = req.extensions().get::<Arc<AppState>>().unwrap();
+    let mut credentials_valid = false;
+    for password_hash in state.db.get_access_token_hashes(user_id).await? {
+        if verify_access_token(&access_token, &password_hash)? {
+            credentials_valid = true;
+            break;
+        }
+    }
 
-//     Ok(user_id)
-// }
+    if !credentials_valid {
+        Err(Error::Http(
+            StatusCode::UNAUTHORIZED,
+            "invalid credentials".to_string(),
+        ))?;
+    }
+
+    Ok::<_, Error>(next.run(req).await)
+}
 
 const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
 

crates/collab/src/main.rs 🔗

@@ -11,8 +11,6 @@ use db::{Db, PostgresDb};
 use serde::Deserialize;
 use std::{net::TcpListener, sync::Arc};
 
-// type Request = tide::Request<Arc<AppState>>;
-
 #[derive(Default, Deserialize)]
 pub struct Config {
     pub http_port: u16,
@@ -22,31 +20,20 @@ pub struct Config {
 
 pub struct AppState {
     db: Arc<dyn Db>,
-    config: Config,
+    api_token: String,
 }
 
 impl AppState {
     async fn new(config: Config) -> Result<Arc<Self>> {
         let db = PostgresDb::new(&config.database_url, 5).await?;
-
         let this = Self {
             db: Arc::new(db),
-            config,
+            api_token: config.api_token.clone(),
         };
         Ok(Arc::new(this))
     }
 }
 
-// trait RequestExt {
-//     fn db(&self) -> &Arc<dyn Db>;
-// }
-
-// impl RequestExt for Request<Body> {
-//     fn db(&self) -> &Arc<dyn Db> {
-//         &self.data::<Arc<AppState>>().unwrap().db
-//     }
-// }
-
 #[tokio::main]
 async fn main() -> Result<()> {
     if std::env::var("LOG_JSON").is_ok() {
@@ -68,7 +55,7 @@ async fn main() -> Result<()> {
     run_server(
         state.clone(),
         rpc,
-        TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
+        TcpListener::bind(&format!("0.0.0.0:{}", config.http_port))
             .expect("failed to bind TCP listener"),
     )
     .await?;
@@ -80,11 +67,11 @@ pub async fn run_server(
     peer: Arc<Peer>,
     listener: TcpListener,
 ) -> Result<()> {
-    let app = Router::<Body>::new();
     // TODO: Compression on API routes?
     // TODO: Authenticate API routes.
 
-    let app = api::add_routes(app, state);
+    let app = Router::<Body>::new().merge(api::routes(state.clone()));
+
     // TODO: Add rpc routes
 
     axum::Server::from_tcp(listener)?