Merge pull request #1174 from zed-industries/bulk-user-creation

Antonio Scandurra created

Expose a new `POST /api/bulk_users` API to create many users at once

Change summary

crates/collab/src/api.rs |  37 ++++++++++++++
crates/collab/src/db.rs  | 106 +++++++++++++++++++++++++++++++++++++++++
2 files changed, 142 insertions(+), 1 deletion(-)

Detailed changes

crates/collab/src/api.rs 🔗

@@ -28,6 +28,7 @@ pub fn routes(rpc_server: &Arc<rpc::Server>, state: Arc<AppState>) -> Router<Bod
             put(update_user).delete(destroy_user).get(get_user),
         )
         .route("/users/:id/access_tokens", post(create_access_token))
+        .route("/bulk_users", post(create_users))
         .route("/invite_codes/:code", get(get_user_for_invite_code))
         .route("/panic", post(trace_panic))
         .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
@@ -167,6 +168,42 @@ async fn get_user(
     Ok(Json(user))
 }
 
+#[derive(Deserialize)]
+struct CreateUsersParams {
+    users: Vec<CreateUsersEntry>,
+}
+
+#[derive(Deserialize)]
+struct CreateUsersEntry {
+    github_login: String,
+    email_address: String,
+    invite_count: usize,
+}
+
+async fn create_users(
+    Json(params): Json<CreateUsersParams>,
+    Extension(app): Extension<Arc<AppState>>,
+) -> Result<Json<Vec<User>>> {
+    let user_ids = app
+        .db
+        .create_users(
+            params
+                .users
+                .into_iter()
+                .map(|params| {
+                    (
+                        params.github_login,
+                        params.email_address,
+                        params.invite_count,
+                    )
+                })
+                .collect(),
+        )
+        .await?;
+    let users = app.db.get_users_by_ids(user_ids).await?;
+    Ok(Json(users))
+}
+
 #[derive(Debug, Deserialize)]
 struct Panic {
     version: String,

crates/collab/src/db.rs 🔗

@@ -6,7 +6,7 @@ use futures::StreamExt;
 use nanoid::nanoid;
 use serde::Serialize;
 pub use sqlx::postgres::PgPoolOptions as DbOptions;
-use sqlx::{types::Uuid, FromRow};
+use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
 use time::OffsetDateTime;
 
 #[async_trait]
@@ -17,6 +17,7 @@ pub trait Db: Send + Sync {
         email_address: Option<&str>,
         admin: bool,
     ) -> Result<UserId>;
+    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
     async fn get_all_users(&self) -> Result<Vec<User>>;
     async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
     async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
@@ -141,6 +142,41 @@ impl Db for PostgresDb {
             .map(UserId)?)
     }
 
+    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
+        let mut query = QueryBuilder::new(
+            "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
+        );
+        query.push_values(
+            users,
+            |mut query, (github_login, email_address, invite_count)| {
+                query
+                    .push_bind(github_login)
+                    .push_bind(email_address)
+                    .push_bind(false)
+                    .push_bind(nanoid!(16))
+                    .push_bind(invite_count as u32);
+            },
+        );
+        query.push(
+            "
+            ON CONFLICT (github_login) DO UPDATE SET
+                github_login = excluded.github_login,
+                invite_count = excluded.invite_count,
+                invite_code = CASE WHEN users.invite_code IS NULL
+                                   THEN excluded.invite_code
+                                   ELSE users.invite_code
+                              END
+            RETURNING id
+            ",
+        );
+
+        let rows = query.build().fetch_all(&self.pool).await?;
+        Ok(rows
+            .into_iter()
+            .filter_map(|row| row.try_get::<UserId, _>(0).ok())
+            .collect())
+    }
+
     async fn get_all_users(&self) -> Result<Vec<User>> {
         let query = "SELECT * FROM users ORDER BY github_login ASC";
         Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
@@ -1021,6 +1057,70 @@ pub mod tests {
         }
     }
 
+    #[tokio::test(flavor = "multi_thread")]
+    async fn test_create_users() {
+        let db = TestDb::postgres().await;
+        let db = db.db();
+
+        // Create the first batch of users, ensuring invite counts are assigned
+        // correctly and the respective invite codes are unique.
+        let user_ids_batch_1 = db
+            .create_users(vec![
+                ("user1".to_string(), "hi@user1.com".to_string(), 5),
+                ("user2".to_string(), "hi@user2.com".to_string(), 4),
+                ("user3".to_string(), "hi@user3.com".to_string(), 3),
+            ])
+            .await
+            .unwrap();
+        assert_eq!(user_ids_batch_1.len(), 3);
+
+        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
+        assert_eq!(users.len(), 3);
+        assert_eq!(users[0].github_login, "user1");
+        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
+        assert_eq!(users[0].invite_count, 5);
+        assert_eq!(users[1].github_login, "user2");
+        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
+        assert_eq!(users[1].invite_count, 4);
+        assert_eq!(users[2].github_login, "user3");
+        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
+        assert_eq!(users[2].invite_count, 3);
+
+        let invite_code_1 = users[0].invite_code.clone().unwrap();
+        let invite_code_2 = users[1].invite_code.clone().unwrap();
+        let invite_code_3 = users[2].invite_code.clone().unwrap();
+        assert_ne!(invite_code_1, invite_code_2);
+        assert_ne!(invite_code_1, invite_code_3);
+        assert_ne!(invite_code_2, invite_code_3);
+
+        // Create the second batch of users and include a user that is already in the database, ensuring
+        // the invite count for the existing user is updated without changing their invite code.
+        let user_ids_batch_2 = db
+            .create_users(vec![
+                ("user2".to_string(), "hi@user2.com".to_string(), 10),
+                ("user4".to_string(), "hi@user4.com".to_string(), 2),
+            ])
+            .await
+            .unwrap();
+        assert_eq!(user_ids_batch_2.len(), 2);
+        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
+
+        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
+        assert_eq!(users.len(), 2);
+        assert_eq!(users[0].github_login, "user2");
+        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
+        assert_eq!(users[0].invite_count, 10);
+        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
+        assert_eq!(users[1].github_login, "user4");
+        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
+        assert_eq!(users[1].invite_count, 2);
+
+        let invite_code_4 = users[1].invite_code.clone().unwrap();
+        assert_ne!(invite_code_4, invite_code_1);
+        assert_ne!(invite_code_4, invite_code_2);
+        assert_ne!(invite_code_4, invite_code_3);
+    }
+
     #[tokio::test(flavor = "multi_thread")]
     async fn test_recent_channel_messages() {
         for test_db in [
@@ -1665,6 +1765,10 @@ pub mod tests {
             }
         }
 
+        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
+            unimplemented!()
+        }
+
         async fn get_all_users(&self) -> Result<Vec<User>> {
             unimplemented!()
         }