Add invite codes / counts to users table

Nathan Sobo and Antonio Scandurra created

Co-Authored-By: Antonio Scandurra <me@as-cii.com>

Change summary

Cargo.lock                                                       |  10 
crates/collab/Cargo.toml                                         |   1 
crates/collab/migrations/20220518151305_add_invites_to_users.sql |   7 
crates/collab/src/db.rs                                          | 256 ++
4 files changed, 274 insertions(+)

Detailed changes

Cargo.lock 🔗

@@ -856,6 +856,7 @@ dependencies = [
  "lipsum",
  "log",
  "lsp",
+ "nanoid",
  "opentelemetry",
  "opentelemetry-otlp",
  "parking_lot",
@@ -2761,6 +2762,15 @@ version = "0.8.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
 
+[[package]]
+name = "nanoid"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3ffa00dec017b5b1a8b7cf5e2c008bfda1aa7e0697ac1508b491fdf2622fb4d8"
+dependencies = [
+ "rand 0.8.3",
+]
+
 [[package]]
 name = "native-tls"
 version = "0.2.10"

crates/collab/Cargo.toml 🔗

@@ -27,6 +27,7 @@ envy = "0.4.2"
 futures = "0.3"
 lazy_static = "1.4"
 lipsum = { version = "0.8", optional = true }
+nanoid = "0.4"
 opentelemetry = { version = "0.17", features = ["rt-tokio"] }
 opentelemetry-otlp = { version = "0.10", features = ["tls-roots"] }
 parking_lot = "0.11.1"

crates/collab/migrations/20220518151305_add_invites_to_users.sql 🔗

@@ -0,0 +1,7 @@
+ALTER TABLE users
+ADD invite_code VARCHAR(64),
+ADD invite_count INTEGER NOT NULL DEFAULT 0,
+ADD inviter_id INTEGER REFERENCES users (id),
+ADD created_at TIMESTAMP NOT NULL DEFAULT NOW();
+
+CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code");

crates/collab/src/db.rs 🔗

@@ -1,6 +1,7 @@
 use anyhow::{anyhow, Context, Result};
 use async_trait::async_trait;
 use futures::StreamExt;
+use nanoid::nanoid;
 use serde::Serialize;
 pub use sqlx::postgres::PgPoolOptions as DbOptions;
 use sqlx::{types::Uuid, FromRow};
@@ -17,6 +18,10 @@ pub trait Db: Send + Sync {
     async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
     async fn destroy_user(&self, id: UserId) -> Result<()>;
 
+    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
+    async fn get_invite_code(&self, id: UserId) -> Result<Option<(String, u32)>>;
+    async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId>;
+
     async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
     async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
     async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
@@ -189,6 +194,103 @@ impl Db for PostgresDb {
             .map(drop)?)
     }
 
+    // invite codes
+
+    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
+        let mut tx = self.pool.begin().await?;
+        sqlx::query(
+            "
+                UPDATE users
+                SET invite_code = $1
+                WHERE id = $2 AND invite_code IS NULL
+            ",
+        )
+        .bind(nanoid!(16))
+        .bind(id)
+        .execute(&mut tx)
+        .await?;
+        sqlx::query(
+            "
+            UPDATE users
+            SET invite_count = $1
+            WHERE id = $2
+            ",
+        )
+        .bind(count)
+        .bind(id)
+        .execute(&mut tx)
+        .await?;
+        tx.commit().await?;
+        Ok(())
+    }
+
+    async fn get_invite_code(&self, id: UserId) -> Result<Option<(String, u32)>> {
+        let result: Option<(String, i32)> = sqlx::query_as(
+            "
+                SELECT invite_code, invite_count
+                FROM users
+                WHERE id = $1 AND invite_code IS NOT NULL 
+            ",
+        )
+        .bind(id)
+        .fetch_optional(&self.pool)
+        .await?;
+        if let Some((code, count)) = result {
+            Ok(Some((code, count.try_into()?)))
+        } else {
+            Ok(None)
+        }
+    }
+
+    async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId> {
+        let mut tx = self.pool.begin().await?;
+
+        let inviter_id: UserId = sqlx::query_scalar(
+            "
+                UPDATE users
+                SET invite_count = invite_count - 1
+                WHERE
+                    invite_code = $1 AND
+                    invite_count > 0
+                RETURNING id
+            ",
+        )
+        .bind(code)
+        .fetch_optional(&mut tx)
+        .await?
+        .ok_or_else(|| anyhow!("invite code not found"))?;
+        let invitee_id = sqlx::query_scalar(
+            "
+                INSERT INTO users
+                    (github_login, admin, inviter_id)
+                VALUES
+                    ($1, 'f', $2)
+                RETURNING id
+            ",
+        )
+        .bind(login)
+        .bind(inviter_id)
+        .fetch_one(&mut tx)
+        .await
+        .map(UserId)?;
+
+        sqlx::query(
+            "
+                INSERT INTO contacts
+                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
+                VALUES
+                    ($1, $2, 't', 't', 't')
+            ",
+        )
+        .bind(inviter_id)
+        .bind(invitee_id)
+        .execute(&mut tx)
+        .await?;
+
+        tx.commit().await?;
+        Ok(invitee_id)
+    }
+
     // contacts
 
     async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
@@ -1198,6 +1300,144 @@ pub mod tests {
         }
     }
 
+    #[tokio::test(flavor = "multi_thread")]
+    async fn test_invite_codes() {
+        let postgres = TestDb::postgres().await;
+        let db = postgres.db();
+        let user1 = db.create_user("user-1", false).await.unwrap();
+
+        // Initially, user 1 has no invite code
+        assert_eq!(db.get_invite_code(user1).await.unwrap(), None);
+
+        // User 1 creates an invite code that can be used twice.
+        db.set_invite_count(user1, 2).await.unwrap();
+        let (invite_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+        assert_eq!(invite_count, 2);
+
+        // User 2 redeems the invite code and becomes a contact of user 1.
+        let user2 = db.redeem_invite_code(&invite_code, "user-2").await.unwrap();
+        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+        assert_eq!(invite_count, 1);
+        assert_eq!(
+            db.get_contacts(user1).await.unwrap(),
+            [
+                Contact::Accepted {
+                    user_id: user1,
+                    should_notify: false
+                },
+                Contact::Accepted {
+                    user_id: user2,
+                    should_notify: true
+                }
+            ]
+        );
+        assert_eq!(
+            db.get_contacts(user2).await.unwrap(),
+            [
+                Contact::Accepted {
+                    user_id: user1,
+                    should_notify: false
+                },
+                Contact::Accepted {
+                    user_id: user2,
+                    should_notify: false
+                }
+            ]
+        );
+
+        // User 3 redeems the invite code and becomes a contact of user 1.
+        let user3 = db.redeem_invite_code(&invite_code, "user-3").await.unwrap();
+        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+        assert_eq!(invite_count, 0);
+        assert_eq!(
+            db.get_contacts(user1).await.unwrap(),
+            [
+                Contact::Accepted {
+                    user_id: user1,
+                    should_notify: false
+                },
+                Contact::Accepted {
+                    user_id: user2,
+                    should_notify: true
+                },
+                Contact::Accepted {
+                    user_id: user3,
+                    should_notify: true
+                }
+            ]
+        );
+        assert_eq!(
+            db.get_contacts(user3).await.unwrap(),
+            [
+                Contact::Accepted {
+                    user_id: user1,
+                    should_notify: false
+                },
+                Contact::Accepted {
+                    user_id: user3,
+                    should_notify: false
+                },
+            ]
+        );
+
+        // Trying to reedem the code for the third time results in an error.
+        db.redeem_invite_code(&invite_code, "user-4")
+            .await
+            .unwrap_err();
+
+        // Invite count can be updated after the code has been created.
+        db.set_invite_count(user1, 2).await.unwrap();
+        let (latest_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
+        assert_eq!(invite_count, 2);
+
+        // User 4 can now redeem the invite code and becomes a contact of user 1.
+        let user4 = db.redeem_invite_code(&invite_code, "user-4").await.unwrap();
+        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+        assert_eq!(invite_count, 1);
+        assert_eq!(
+            db.get_contacts(user1).await.unwrap(),
+            [
+                Contact::Accepted {
+                    user_id: user1,
+                    should_notify: false
+                },
+                Contact::Accepted {
+                    user_id: user2,
+                    should_notify: true
+                },
+                Contact::Accepted {
+                    user_id: user3,
+                    should_notify: true
+                },
+                Contact::Accepted {
+                    user_id: user4,
+                    should_notify: true
+                }
+            ]
+        );
+        assert_eq!(
+            db.get_contacts(user4).await.unwrap(),
+            [
+                Contact::Accepted {
+                    user_id: user1,
+                    should_notify: false
+                },
+                Contact::Accepted {
+                    user_id: user4,
+                    should_notify: false
+                },
+            ]
+        );
+
+        // An existing user cannot redeem invite codes.
+        db.redeem_invite_code(&invite_code, "user-2")
+            .await
+            .unwrap_err();
+        let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+        assert_eq!(invite_count, 1);
+    }
+
     pub struct TestDb {
         pub db: Option<Arc<dyn Db>>,
         pub url: String,
@@ -1348,6 +1588,22 @@ pub mod tests {
             unimplemented!()
         }
 
+        // invite codes
+
+        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
+            unimplemented!()
+        }
+
+        async fn get_invite_code(&self, _id: UserId) -> Result<Option<(String, u32)>> {
+            unimplemented!()
+        }
+
+        async fn redeem_invite_code(&self, _code: &str, _login: &str) -> Result<UserId> {
+            unimplemented!()
+        }
+
+        // contacts
+
         async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
             self.background.simulate_random_delay().await;
             let mut contacts = vec![Contact::Accepted {