Enable descriptive HTTP errors to be returned from DB layer

Nathan Sobo and Antonio Scandurra created

For now, we only use this when redeeming an invite code.

Co-Authored-By: Antonio Scandurra <me@as-cii.com>

Change summary

Cargo.lock                |  1 +
crates/collab/Cargo.toml  |  1 +
crates/collab/src/auth.rs |  5 +++--
crates/collab/src/db.rs   | 33 ++++++++++++++++++++++++++++-----
crates/collab/src/main.rs | 12 ++++++++++++
5 files changed, 45 insertions(+), 7 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -851,6 +851,7 @@ dependencies = [
  "envy",
  "futures",
  "gpui",
+ "hyper",
  "language",
  "lazy_static",
  "lipsum",

crates/collab/Cargo.toml 🔗

@@ -25,6 +25,7 @@ base64 = "0.13"
 clap = { version = "3.1", features = ["derive"], optional = true }
 envy = "0.4.2"
 futures = "0.3"
+hyper = "0.14"
 lazy_static = "1.4"
 lipsum = { version = "0.8", optional = true }
 nanoid = "0.4"

crates/collab/src/auth.rs 🔗

@@ -91,7 +91,8 @@ fn hash_access_token(token: &str) -> Result<String> {
             None,
             params,
             &SaltString::generate(thread_rng()),
-        )?
+        )
+        .map_err(anyhow::Error::new)?
         .to_string())
 }
 
@@ -105,6 +106,6 @@ pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<St
 }
 
 pub fn verify_access_token(token: &str, hash: &str) -> Result<bool> {
-    let hash = PasswordHash::new(hash)?;
+    let hash = PasswordHash::new(hash).map_err(anyhow::Error::new)?;
     Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
 }

crates/collab/src/db.rs 🔗

@@ -1,6 +1,7 @@
-use crate::Result;
+use crate::{Error, Result};
 use anyhow::{anyhow, Context};
 use async_trait::async_trait;
+use axum::http::StatusCode;
 use futures::StreamExt;
 use nanoid::nanoid;
 use serde::Serialize;
@@ -237,7 +238,7 @@ impl Db for PostgresDb {
         .fetch_optional(&self.pool)
         .await?;
         if let Some((code, count)) = result {
-            Ok(Some((code, count.try_into()?)))
+            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
         } else {
             Ok(None)
         }
@@ -246,7 +247,7 @@ impl Db for PostgresDb {
     async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId> {
         let mut tx = self.pool.begin().await?;
 
-        let inviter_id: UserId = sqlx::query_scalar(
+        let inviter_id: Option<UserId> = sqlx::query_scalar(
             "
                 UPDATE users
                 SET invite_count = invite_count - 1
@@ -258,8 +259,30 @@ impl Db for PostgresDb {
         )
         .bind(code)
         .fetch_optional(&mut tx)
-        .await?
-        .ok_or_else(|| anyhow!("invite code not found"))?;
+        .await?;
+
+        let inviter_id = match inviter_id {
+            Some(inviter_id) => inviter_id,
+            None => {
+                if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
+                    .bind(code)
+                    .fetch_optional(&mut tx)
+                    .await?
+                    .is_some()
+                {
+                    Err(Error::Http(
+                        StatusCode::UNAUTHORIZED,
+                        "no invites remaining".to_string(),
+                    ))?
+                } else {
+                    Err(Error::Http(
+                        StatusCode::NOT_FOUND,
+                        "invite code not found".to_string(),
+                    ))?
+                }
+            }
+        };
+
         let invitee_id = sqlx::query_scalar(
             "
                 INSERT INTO users

crates/collab/src/main.rs 🔗

@@ -88,6 +88,18 @@ impl From<sqlx::Error> for Error {
     }
 }
 
+impl From<axum::Error> for Error {
+    fn from(error: axum::Error) -> Self {
+        Self::Internal(error.into())
+    }
+}
+
+impl From<hyper::Error> for Error {
+    fn from(error: hyper::Error) -> Self {
+        Self::Internal(error.into())
+    }
+}
+
 impl IntoResponse for Error {
     fn into_response(self) -> axum::response::Response {
         match self {