@@ -1,6 +1,7 @@
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use futures::StreamExt;
+use nanoid::nanoid;
use serde::Serialize;
pub use sqlx::postgres::PgPoolOptions as DbOptions;
use sqlx::{types::Uuid, FromRow};
@@ -17,6 +18,10 @@ pub trait Db: Send + Sync {
async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
async fn destroy_user(&self, id: UserId) -> Result<()>;
+ async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
+ async fn get_invite_code(&self, id: UserId) -> Result<Option<(String, u32)>>;
+ async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId>;
+
async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
@@ -189,6 +194,103 @@ impl Db for PostgresDb {
.map(drop)?)
}
+ // invite codes
+
+ async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
+ let mut tx = self.pool.begin().await?;
+ sqlx::query(
+ "
+ UPDATE users
+ SET invite_code = $1
+ WHERE id = $2 AND invite_code IS NULL
+ ",
+ )
+ .bind(nanoid!(16))
+ .bind(id)
+ .execute(&mut tx)
+ .await?;
+ sqlx::query(
+ "
+ UPDATE users
+ SET invite_count = $1
+ WHERE id = $2
+ ",
+ )
+ .bind(count)
+ .bind(id)
+ .execute(&mut tx)
+ .await?;
+ tx.commit().await?;
+ Ok(())
+ }
+
+ async fn get_invite_code(&self, id: UserId) -> Result<Option<(String, u32)>> {
+ let result: Option<(String, i32)> = sqlx::query_as(
+ "
+ SELECT invite_code, invite_count
+ FROM users
+ WHERE id = $1 AND invite_code IS NOT NULL
+ ",
+ )
+ .bind(id)
+ .fetch_optional(&self.pool)
+ .await?;
+ if let Some((code, count)) = result {
+ Ok(Some((code, count.try_into()?)))
+ } else {
+ Ok(None)
+ }
+ }
+
+ async fn redeem_invite_code(&self, code: &str, login: &str) -> Result<UserId> {
+ let mut tx = self.pool.begin().await?;
+
+ let inviter_id: UserId = sqlx::query_scalar(
+ "
+ UPDATE users
+ SET invite_count = invite_count - 1
+ WHERE
+ invite_code = $1 AND
+ invite_count > 0
+ RETURNING id
+ ",
+ )
+ .bind(code)
+ .fetch_optional(&mut tx)
+ .await?
+ .ok_or_else(|| anyhow!("invite code not found"))?;
+ let invitee_id = sqlx::query_scalar(
+ "
+ INSERT INTO users
+ (github_login, admin, inviter_id)
+ VALUES
+ ($1, 'f', $2)
+ RETURNING id
+ ",
+ )
+ .bind(login)
+ .bind(inviter_id)
+ .fetch_one(&mut tx)
+ .await
+ .map(UserId)?;
+
+ sqlx::query(
+ "
+ INSERT INTO contacts
+ (user_id_a, user_id_b, a_to_b, should_notify, accepted)
+ VALUES
+ ($1, $2, 't', 't', 't')
+ ",
+ )
+ .bind(inviter_id)
+ .bind(invitee_id)
+ .execute(&mut tx)
+ .await?;
+
+ tx.commit().await?;
+ Ok(invitee_id)
+ }
+
// contacts
async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
@@ -1198,6 +1300,144 @@ pub mod tests {
}
}
+ #[tokio::test(flavor = "multi_thread")]
+ async fn test_invite_codes() {
+ let postgres = TestDb::postgres().await;
+ let db = postgres.db();
+ let user1 = db.create_user("user-1", false).await.unwrap();
+
+ // Initially, user 1 has no invite code
+ assert_eq!(db.get_invite_code(user1).await.unwrap(), None);
+
+ // User 1 creates an invite code that can be used twice.
+ db.set_invite_count(user1, 2).await.unwrap();
+ let (invite_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+ assert_eq!(invite_count, 2);
+
+ // User 2 redeems the invite code and becomes a contact of user 1.
+ let user2 = db.redeem_invite_code(&invite_code, "user-2").await.unwrap();
+ let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+ assert_eq!(invite_count, 1);
+ assert_eq!(
+ db.get_contacts(user1).await.unwrap(),
+ [
+ Contact::Accepted {
+ user_id: user1,
+ should_notify: false
+ },
+ Contact::Accepted {
+ user_id: user2,
+ should_notify: true
+ }
+ ]
+ );
+ assert_eq!(
+ db.get_contacts(user2).await.unwrap(),
+ [
+ Contact::Accepted {
+ user_id: user1,
+ should_notify: false
+ },
+ Contact::Accepted {
+ user_id: user2,
+ should_notify: false
+ }
+ ]
+ );
+
+ // User 3 redeems the invite code and becomes a contact of user 1.
+ let user3 = db.redeem_invite_code(&invite_code, "user-3").await.unwrap();
+ let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+ assert_eq!(invite_count, 0);
+ assert_eq!(
+ db.get_contacts(user1).await.unwrap(),
+ [
+ Contact::Accepted {
+ user_id: user1,
+ should_notify: false
+ },
+ Contact::Accepted {
+ user_id: user2,
+ should_notify: true
+ },
+ Contact::Accepted {
+ user_id: user3,
+ should_notify: true
+ }
+ ]
+ );
+ assert_eq!(
+ db.get_contacts(user3).await.unwrap(),
+ [
+ Contact::Accepted {
+ user_id: user1,
+ should_notify: false
+ },
+ Contact::Accepted {
+ user_id: user3,
+ should_notify: false
+ },
+ ]
+ );
+
+ // Trying to reedem the code for the third time results in an error.
+ db.redeem_invite_code(&invite_code, "user-4")
+ .await
+ .unwrap_err();
+
+ // Invite count can be updated after the code has been created.
+ db.set_invite_count(user1, 2).await.unwrap();
+ let (latest_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+ assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
+ assert_eq!(invite_count, 2);
+
+ // User 4 can now redeem the invite code and becomes a contact of user 1.
+ let user4 = db.redeem_invite_code(&invite_code, "user-4").await.unwrap();
+ let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+ assert_eq!(invite_count, 1);
+ assert_eq!(
+ db.get_contacts(user1).await.unwrap(),
+ [
+ Contact::Accepted {
+ user_id: user1,
+ should_notify: false
+ },
+ Contact::Accepted {
+ user_id: user2,
+ should_notify: true
+ },
+ Contact::Accepted {
+ user_id: user3,
+ should_notify: true
+ },
+ Contact::Accepted {
+ user_id: user4,
+ should_notify: true
+ }
+ ]
+ );
+ assert_eq!(
+ db.get_contacts(user4).await.unwrap(),
+ [
+ Contact::Accepted {
+ user_id: user1,
+ should_notify: false
+ },
+ Contact::Accepted {
+ user_id: user4,
+ should_notify: false
+ },
+ ]
+ );
+
+ // An existing user cannot redeem invite codes.
+ db.redeem_invite_code(&invite_code, "user-2")
+ .await
+ .unwrap_err();
+ let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap();
+ assert_eq!(invite_count, 1);
+ }
+
pub struct TestDb {
pub db: Option<Arc<dyn Db>>,
pub url: String,
@@ -1348,6 +1588,22 @@ pub mod tests {
unimplemented!()
}
+ // invite codes
+
+ async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
+ unimplemented!()
+ }
+
+ async fn get_invite_code(&self, _id: UserId) -> Result<Option<(String, u32)>> {
+ unimplemented!()
+ }
+
+ async fn redeem_invite_code(&self, _code: &str, _login: &str) -> Result<UserId> {
+ unimplemented!()
+ }
+
+ // contacts
+
async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
self.background.simulate_random_delay().await;
let mut contacts = vec![Contact::Accepted {