Add ability to get the user for an invite code in collab API

Nathan Sobo and Antonio Scandurra created

Co-Authored-By: Antonio Scandurra <me@as-cii.com>

Change summary

crates/collab/src/api.rs |  8 +++++++
crates/collab/src/db.rs  | 46 ++++++++++++++++++++++++++++++++---------
2 files changed, 44 insertions(+), 10 deletions(-)

Detailed changes

crates/collab/src/api.rs 🔗

@@ -26,6 +26,7 @@ pub fn routes(state: Arc<AppState>) -> Router<Body> {
             put(update_user).delete(destroy_user).get(get_user),
         )
         .route("/users/:id/access_tokens", post(create_access_token))
+        .route("/invite_codes/:code", get(get_user_for_invite_code))
         .route("/panic", post(trace_panic))
         .layer(
             ServiceBuilder::new()
@@ -210,3 +211,10 @@ async fn create_access_token(
         encrypted_access_token,
     }))
 }
+
+async fn get_user_for_invite_code(
+    Path(code): Path<String>,
+    Extension(app): Extension<Arc<AppState>>,
+) -> Result<Json<User>> {
+    Ok(Json(app.db.get_user_for_invite_code(&code).await?))
+}

crates/collab/src/db.rs 🔗

@@ -21,7 +21,8 @@ pub trait Db: Send + Sync {
     async fn destroy_user(&self, id: UserId) -> Result<()>;
 
     async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
-    async fn get_invite_code(&self, id: UserId) -> Result<Option<(String, u32)>>;
+    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
+    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
     async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId>;
 
     async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
@@ -226,7 +227,7 @@ impl Db for PostgresDb {
         Ok(())
     }
 
-    async fn get_invite_code(&self, id: UserId) -> Result<Option<(String, u32)>> {
+    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
         let result: Option<(String, i32)> = sqlx::query_as(
             "
                 SELECT invite_code, invite_count
@@ -244,6 +245,25 @@ impl Db for PostgresDb {
         }
     }
 
+    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
+        sqlx::query_as(
+            "
+                SELECT *
+                FROM users
+                WHERE invite_code = $1
+            ",
+        )
+        .bind(code)
+        .fetch_optional(&self.pool)
+        .await?
+        .ok_or_else(|| {
+            Error::Http(
+                StatusCode::NOT_FOUND,
+                "that invite code does not exist".to_string(),
+            )
+        })
+    }
+
     async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId> {
         let mut tx = self.pool.begin().await?;
 
@@ -1337,16 +1357,17 @@ pub mod tests {
         let user1 = db.create_user("user-1", false).await.unwrap();
 
         // Initially, user 1 has no invite code
-        assert_eq!(db.get_invite_code(user1).await.unwrap(), None);
+        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);
 
         // User 1 creates an invite code that can be used twice.
         db.set_invite_count(user1, 2).await.unwrap();
-        let (invite_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+        let (invite_code, invite_count) =
+            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
         assert_eq!(invite_count, 2);
 
         // User 2 redeems the invite code and becomes a contact of user 1.
         let user2 = db.redeem_invite_code(&invite_code, "user-2").await.unwrap();
-        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
         assert_eq!(invite_count, 1);
         assert_eq!(
             db.get_contacts(user1).await.unwrap(),
@@ -1377,7 +1398,7 @@ pub mod tests {
 
         // User 3 redeems the invite code and becomes a contact of user 1.
         let user3 = db.redeem_invite_code(&invite_code, "user-3").await.unwrap();
-        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
         assert_eq!(invite_count, 0);
         assert_eq!(
             db.get_contacts(user1).await.unwrap(),
@@ -1417,13 +1438,14 @@ pub mod tests {
 
         // Invite count can be updated after the code has been created.
         db.set_invite_count(user1, 2).await.unwrap();
-        let (latest_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+        let (latest_code, invite_count) =
+            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
         assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
         assert_eq!(invite_count, 2);
 
         // User 4 can now redeem the invite code and becomes a contact of user 1.
         let user4 = db.redeem_invite_code(&invite_code, "user-4").await.unwrap();
-        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
         assert_eq!(invite_count, 1);
         assert_eq!(
             db.get_contacts(user1).await.unwrap(),
@@ -1464,7 +1486,7 @@ pub mod tests {
         db.redeem_invite_code(&invite_code, "user-2")
             .await
             .unwrap_err();
-        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
         assert_eq!(invite_count, 1);
     }
 
@@ -1626,7 +1648,11 @@ pub mod tests {
             unimplemented!()
         }
 
-        async fn get_invite_code(&self, _id: UserId) -> Result<Option<(String, u32)>> {
+        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
+            unimplemented!()
+        }
+
+        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
             unimplemented!()
         }