Finish adding API routes

Nathan Sobo and Max Brunsfeld created

We haven't tested them yet.

Co-Authored-By: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

crates/collab/src/api.rs  | 201 +++++++++++++++++++---------------------
crates/collab/src/auth.rs |  96 ++++++++----------
crates/collab/src/main.rs |  26 ++++-
3 files changed, 159 insertions(+), 164 deletions(-)

Detailed changes

crates/collab/src/api.rs 🔗

@@ -1,59 +1,62 @@
 use crate::{
-    db::{Db, User, UserId},
-    AppState, Result,
+    auth,
+    db::{User, UserId},
+    AppState, Error, Result,
 };
 use anyhow::anyhow;
 use axum::{
     body::Body,
-    extract::Path,
-    http::{Request, StatusCode},
-    response::{IntoResponse, Response},
-    routing::{get, put},
+    extract::{Path, Query},
+    http::StatusCode,
+    routing::{delete, get, post, put},
     Json, Router,
 };
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 use std::sync::Arc;
 
 pub fn add_routes(router: Router<Body>, app: Arc<AppState>) -> Router<Body> {
     router
         .route("/users", {
             let app = app.clone();
-            get(move |req| get_users(req, app))
+            get(move || get_users(app))
         })
         .route("/users", {
             let app = app.clone();
-            get(move |params| create_user(params, app))
+            post(move |params| create_user(params, app))
         })
         .route("/users/:id", {
             let app = app.clone();
             put(move |user_id, params| update_user(user_id, params, app))
         })
+        .route("/users/:id", {
+            let app = app.clone();
+            delete(move |user_id| destroy_user(user_id, app))
+        })
+        .route("/users/:github_login", {
+            let app = app.clone();
+            get(move |github_login| get_user(github_login, app))
+        })
+        .route("/users/:github_login/access_tokens", {
+            let app = app.clone();
+            post(move |github_login, params| create_access_token(github_login, params, app))
+        })
 }
 
-// pub fn add_routes(app: &mut tide::Server<Arc<AppState>>) {
-//     app.at("/users").get(get_users);
-//     app.at("/users").post(create_user);
-//     app.at("/users/:id").put(update_user);
-//     app.at("/users/:id").delete(destroy_user);
-//     app.at("/users/:github_login").get(get_user);
-//     app.at("/users/:github_login/access_tokens")
-//         .post(create_access_token);
-// }
-
-async fn get_users(request: Request<Body>, app: Arc<AppState>) -> Result<Json<Vec<User>>> {
-    // request.require_token().await?;
-
+async fn get_users(app: Arc<AppState>) -> Result<Json<Vec<User>>> {
     let users = app.db.get_all_users().await?;
     Ok(Json(users))
 }
 
 #[derive(Deserialize)]
-struct CreateUser {
+struct CreateUserParams {
     github_login: String,
     admin: bool,
 }
 
-async fn create_user(Json(params): Json<CreateUser>, app: Arc<AppState>) -> Result<Json<User>> {
+async fn create_user(
+    Json(params): Json<CreateUserParams>,
+    app: Arc<AppState>,
+) -> Result<Json<User>> {
     let user_id = app
         .db
         .create_user(&params.github_login, params.admin)
@@ -69,102 +72,88 @@ async fn create_user(Json(params): Json<CreateUser>, app: Arc<AppState>) -> Resu
 }
 
 #[derive(Deserialize)]
-struct UpdateUser {
+struct UpdateUserParams {
     admin: bool,
 }
 
 async fn update_user(
     Path(user_id): Path<i32>,
-    Json(params): Json<UpdateUser>,
+    Json(params): Json<UpdateUserParams>,
     app: Arc<AppState>,
-) -> Result<impl IntoResponse> {
-    let user_id = UserId(user_id);
-    app.db.set_user_is_admin(user_id, params.admin).await?;
+) -> Result<()> {
+    app.db
+        .set_user_is_admin(UserId(user_id), params.admin)
+        .await?;
     Ok(())
 }
 
-// async fn update_user(mut request: Request) -> tide::Result {
-//     request.require_token().await?;
-
-//     #[derive(Deserialize)]
-//     struct Params {
-//         admin: bool,
-//     }
-
-//     request
-//         .db()
-//         .set_user_is_admin(user_id, params.admin)
-//         .await?;
-
-//     Ok(tide::Response::builder(StatusCode::Ok).build())
-// }
-
-// async fn destroy_user(request: Request) -> tide::Result {
-//     request.require_token().await?;
-//     let user_id = UserId(
-//         request
-//             .param("id")?
-//             .parse::<i32>()
-//             .map_err(|error| surf::Error::from_str(StatusCode::BadRequest, error.to_string()))?,
-//     );
-
-//     request.db().destroy_user(user_id).await?;
-
-//     Ok(tide::Response::builder(StatusCode::Ok).build())
-// }
-
-// async fn create_access_token(request: Request) -> tide::Result {
-//     request.require_token().await?;
+async fn destroy_user(Path(user_id): Path<i32>, app: Arc<AppState>) -> Result<()> {
+    app.db.destroy_user(UserId(user_id)).await?;
+    Ok(())
+}
 
-//     let user = request
-//         .db()
-//         .get_user_by_github_login(request.param("github_login")?)
-//         .await?
-//         .ok_or_else(|| surf::Error::from_str(StatusCode::NotFound, "user not found"))?;
+async fn get_user(Path(login): Path<String>, app: Arc<AppState>) -> Result<Json<User>> {
+    let user = app
+        .db
+        .get_user_by_github_login(&login)
+        .await?
+        .ok_or_else(|| anyhow!("user not found"))?;
+    Ok(Json(user))
+}
 
-//     #[derive(Deserialize)]
-//     struct QueryParams {
-//         public_key: String,
-//         impersonate: Option<String>,
-//     }
+#[derive(Deserialize)]
+struct CreateAccessTokenQueryParams {
+    public_key: String,
+    impersonate: Option<String>,
+}
 
-//     let query_params: QueryParams = request.query().map_err(|_| {
-//         surf::Error::from_str(StatusCode::UnprocessableEntity, "invalid query params")
-//     })?;
-
-//     let mut user_id = user.id;
-//     if let Some(impersonate) = query_params.impersonate {
-//         if user.admin {
-//             if let Some(impersonated_user) =
-//                 request.db().get_user_by_github_login(&impersonate).await?
-//             {
-//                 user_id = impersonated_user.id;
-//             } else {
-//                 return Ok(tide::Response::builder(StatusCode::UnprocessableEntity)
-//                     .body(format!(
-//                         "Can't impersonate non-existent user {}",
-//                         impersonate
-//                     ))
-//                     .build());
-//             }
-//         } else {
-//             return Ok(tide::Response::builder(StatusCode::Unauthorized)
-//                 .body(format!(
-//                     "Can't impersonate user {} because the real user isn't an admin",
-//                     impersonate
-//                 ))
-//                 .build());
-//         }
-//     }
+#[derive(Serialize)]
+struct CreateAccessTokenResponse {
+    user_id: UserId,
+    encrypted_access_token: String,
+}
 
-//     let access_token = auth::create_access_token(request.db().as_ref(), user_id).await?;
-//     let encrypted_access_token =
-//         auth::encrypt_access_token(&access_token, query_params.public_key.clone())?;
+async fn create_access_token(
+    Path(login): Path<String>,
+    Query(params): Query<CreateAccessTokenQueryParams>,
+    app: Arc<AppState>,
+) -> Result<Json<CreateAccessTokenResponse>> {
+    //     request.require_token().await?;
 
-//     Ok(tide::Response::builder(StatusCode::Ok)
-//         .body(json!({"user_id": user_id, "encrypted_access_token": encrypted_access_token}))
-//         .build())
-// }
+    let user = app
+        .db
+        .get_user_by_github_login(&login)
+        .await?
+        .ok_or_else(|| anyhow!("user not found"))?;
+
+    let mut user_id = user.id;
+    if let Some(impersonate) = params.impersonate {
+        if user.admin {
+            if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
+                user_id = impersonated_user.id;
+            } else {
+                return Err(Error::Http(
+                    StatusCode::UNPROCESSABLE_ENTITY,
+                    format!("user {impersonate} does not exist"),
+                ));
+            }
+        } else {
+            return Err(Error::Http(
+                StatusCode::UNAUTHORIZED,
+                format!("you do not have permission to impersonate other users"),
+            ));
+        }
+    }
+
+    let access_token = auth::create_access_token(app.db.as_ref(), user_id).await?;
+    let encrypted_access_token =
+        auth::encrypt_access_token(&access_token, params.public_key.clone())?;
+
+    Ok(Json(CreateAccessTokenResponse {
+        user_id,
+        encrypted_access_token,
+    }))
+}
 
 // #[async_trait]
 // pub trait RequestExt {

crates/collab/src/auth.rs 🔗

@@ -1,18 +1,10 @@
-// use super::{
-//     db::{self, UserId},
-//     errors::TideResultExt,
-// };
-// use crate::Request;
-// use anyhow::{anyhow, Context};
-// use rand::thread_rng;
-// use rpc::auth as zed_auth;
-// use scrypt::{
-//     password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
-//     Scrypt,
-// };
-// use std::convert::TryFrom;
-// use surf::StatusCode;
-// use tide::Error;
+use super::db::{self, UserId};
+use anyhow::{Context, Result};
+use rand::thread_rng;
+use scrypt::{
+    password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
+    Scrypt,
+};
 
 // pub async fn process_auth_header(request: &Request) -> tide::Result<UserId> {
 //     let mut auth_header = request
@@ -58,45 +50,45 @@
 //     Ok(user_id)
 // }
 
-// const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
+const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
 
-// pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> tide::Result<String> {
-//     let access_token = zed_auth::random_token();
-//     let access_token_hash =
-//         hash_access_token(&access_token).context("failed to hash access token")?;
-//     db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
-//         .await?;
-//     Ok(access_token)
-// }
+pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> Result<String> {
+    let access_token = rpc::auth::random_token();
+    let access_token_hash =
+        hash_access_token(&access_token).context("failed to hash access token")?;
+    db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
+        .await?;
+    Ok(access_token)
+}
 
-// fn hash_access_token(token: &str) -> tide::Result<String> {
-//     // Avoid slow hashing in debug mode.
-//     let params = if cfg!(debug_assertions) {
-//         scrypt::Params::new(1, 1, 1).unwrap()
-//     } else {
-//         scrypt::Params::recommended()
-//     };
+fn hash_access_token(token: &str) -> Result<String> {
+    // Avoid slow hashing in debug mode.
+    let params = if cfg!(debug_assertions) {
+        scrypt::Params::new(1, 1, 1).unwrap()
+    } else {
+        scrypt::Params::recommended()
+    };
 
-//     Ok(Scrypt
-//         .hash_password(
-//             token.as_bytes(),
-//             None,
-//             params,
-//             &SaltString::generate(thread_rng()),
-//         )?
-//         .to_string())
-// }
+    Ok(Scrypt
+        .hash_password(
+            token.as_bytes(),
+            None,
+            params,
+            &SaltString::generate(thread_rng()),
+        )?
+        .to_string())
+}
 
-// pub fn encrypt_access_token(access_token: &str, public_key: String) -> tide::Result<String> {
-//     let native_app_public_key =
-//         zed_auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
-//     let encrypted_access_token = native_app_public_key
-//         .encrypt_string(&access_token)
-//         .context("failed to encrypt access token with public key")?;
-//     Ok(encrypted_access_token)
-// }
+pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
+    let native_app_public_key =
+        rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
+    let encrypted_access_token = native_app_public_key
+        .encrypt_string(&access_token)
+        .context("failed to encrypt access token with public key")?;
+    Ok(encrypted_access_token)
+}
 
-// pub fn verify_access_token(token: &str, hash: &str) -> tide::Result<bool> {
-//     let hash = PasswordHash::new(hash)?;
-//     Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
-// }
+pub fn verify_access_token(token: &str, hash: &str) -> Result<bool> {
+    let hash = PasswordHash::new(hash)?;
+    Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
+}

crates/collab/src/main.rs 🔗

@@ -94,33 +94,47 @@ pub async fn run_server(
     Ok(())
 }
 
-type Result<T> = std::result::Result<T, Error>;
+pub type Result<T> = std::result::Result<T, Error>;
 
-struct Error(anyhow::Error);
+pub enum Error {
+    Http(StatusCode, String),
+    Internal(anyhow::Error),
+}
 
 impl<E> From<E> for Error
 where
     E: Into<anyhow::Error>,
 {
     fn from(error: E) -> Self {
-        Self(error.into())
+        Self::Internal(error.into())
     }
 }
 
 impl IntoResponse for Error {
     fn into_response(self) -> axum::response::Response {
-        (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &self.0)).into_response()
+        match self {
+            Error::Http(code, message) => (code, message).into_response(),
+            Error::Internal(error) => {
+                (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
+            }
+        }
     }
 }
 
 impl std::fmt::Debug for Error {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        self.0.fmt(f)
+        match self {
+            Error::Http(code, message) => (code, message).fmt(f),
+            Error::Internal(error) => error.fmt(f),
+        }
     }
 }
 
 impl std::fmt::Display for Error {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        self.0.fmt(f)
+        match self {
+            Error::Http(code, message) => write!(f, "{code}: {message}"),
+            Error::Internal(error) => error.fmt(f),
+        }
     }
 }