Update user retrieval API to take both github user id and github login

Max Brunsfeld created

Change summary

crates/collab/src/api.rs               | 39 +++++------
crates/collab/src/db.rs                | 85 ++++++++++++++++++++++++---
crates/collab/src/db_tests.rs          | 58 +++++++++++++++++++
crates/collab/src/integration_tests.rs | 20 ++++-
crates/collab/src/rpc.rs               |  2 
5 files changed, 165 insertions(+), 39 deletions(-)

Detailed changes

crates/collab/src/api.rs 🔗

@@ -25,10 +25,7 @@ use tracing::instrument;
 pub fn routes(rpc_server: &Arc<rpc::Server>, state: Arc<AppState>) -> Router<Body> {
     Router::new()
         .route("/users", get(get_users).post(create_user))
-        .route(
-            "/users/:id",
-            put(update_user).delete(destroy_user).get(get_user),
-        )
+        .route("/users/:id", put(update_user).delete(destroy_user))
         .route("/users/:id/access_tokens", post(create_access_token))
         .route("/users_with_no_invites", get(get_users_with_no_invites))
         .route("/invite_codes/:code", get(get_user_for_invite_code))
@@ -90,6 +87,8 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
 
 #[derive(Debug, Deserialize)]
 struct GetUsersQueryParams {
+    github_user_id: Option<i32>,
+    github_login: Option<String>,
     query: Option<String>,
     page: Option<u32>,
     limit: Option<u32>,
@@ -99,6 +98,14 @@ async fn get_users(
     Query(params): Query<GetUsersQueryParams>,
     Extension(app): Extension<Arc<AppState>>,
 ) -> Result<Json<Vec<User>>> {
+    if let Some(github_login) = &params.github_login {
+        let user = app
+            .db
+            .get_user_by_github_account(github_login, params.github_user_id)
+            .await?;
+        return Ok(Json(Vec::from_iter(user)));
+    }
+
     let limit = params.limit.unwrap_or(100);
     let users = if let Some(query) = params.query {
         app.db.fuzzy_search_users(&query, limit).await?
@@ -205,18 +212,6 @@ async fn destroy_user(
     Ok(())
 }
 
-async fn get_user(
-    Path(login): Path<String>,
-    Extension(app): Extension<Arc<AppState>>,
-) -> Result<Json<User>> {
-    let user = app
-        .db
-        .get_user_by_github_login(&login)
-        .await?
-        .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "User not found".to_string()))?;
-    Ok(Json(user))
-}
-
 #[derive(Debug, Deserialize)]
 struct GetUsersWithNoInvites {
     invited_by_another_user: bool,
@@ -351,22 +346,24 @@ struct CreateAccessTokenResponse {
 }
 
 async fn create_access_token(
-    Path(login): Path<String>,
+    Path(user_id): Path<UserId>,
     Query(params): Query<CreateAccessTokenQueryParams>,
     Extension(app): Extension<Arc<AppState>>,
 ) -> Result<Json<CreateAccessTokenResponse>> {
-    //     request.require_token().await?;
-
     let user = app
         .db
-        .get_user_by_github_login(&login)
+        .get_user_by_id(user_id)
         .await?
         .ok_or_else(|| anyhow!("user not found"))?;
 
     let mut user_id = user.id;
     if let Some(impersonate) = params.impersonate {
         if user.admin {
-            if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
+            if let Some(impersonated_user) = app
+                .db
+                .get_user_by_github_account(&impersonate, None)
+                .await?
+            {
                 user_id = impersonated_user.id;
             } else {
                 return Err(Error::Http(

crates/collab/src/db.rs 🔗

@@ -23,7 +23,11 @@ pub trait Db: Send + Sync {
     async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
     async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
     async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
-    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
+    async fn get_user_by_github_account(
+        &self,
+        github_login: &str,
+        github_user_id: Option<i32>,
+    ) -> Result<Option<User>>;
     async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
     async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
     async fn destroy_user(&self, id: UserId) -> Result<()>;
@@ -274,12 +278,53 @@ impl Db for PostgresDb {
         Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
     }
 
-    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
-        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
-        Ok(sqlx::query_as(query)
+    async fn get_user_by_github_account(
+        &self,
+        github_login: &str,
+        github_user_id: Option<i32>,
+    ) -> Result<Option<User>> {
+        if let Some(github_user_id) = github_user_id {
+            let mut user = sqlx::query_as::<_, User>(
+                "
+                UPDATE users
+                SET github_login = $1
+                WHERE github_user_id = $2
+                RETURNING *
+                ",
+            )
+            .bind(github_login)
+            .bind(github_user_id)
+            .fetch_optional(&self.pool)
+            .await?;
+
+            if user.is_none() {
+                user = sqlx::query_as::<_, User>(
+                    "
+                    UPDATE users
+                    SET github_user_id = $1
+                    WHERE github_login = $2
+                    RETURNING *
+                    ",
+                )
+                .bind(github_user_id)
+                .bind(github_login)
+                .fetch_optional(&self.pool)
+                .await?;
+            }
+
+            Ok(user)
+        } else {
+            Ok(sqlx::query_as(
+                "
+                SELECT * FROM users
+                WHERE github_login = $1
+                LIMIT 1
+                ",
+            )
             .bind(github_login)
             .fetch_optional(&self.pool)
             .await?)
+        }
     }
 
     async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
@@ -1777,14 +1822,32 @@ mod test {
             unimplemented!()
         }
 
-        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
+        async fn get_user_by_github_account(
+            &self,
+            github_login: &str,
+            github_user_id: Option<i32>,
+        ) -> Result<Option<User>> {
             self.background.simulate_random_delay().await;
-            Ok(self
-                .users
-                .lock()
-                .values()
-                .find(|user| user.github_login == github_login)
-                .cloned())
+            if let Some(github_user_id) = github_user_id {
+                for user in self.users.lock().values_mut() {
+                    if user.github_user_id == github_user_id {
+                        user.github_login = github_login.into();
+                        return Ok(Some(user.clone()));
+                    }
+                    if user.github_login == github_login {
+                        user.github_user_id = github_user_id;
+                        return Ok(Some(user.clone()));
+                    }
+                }
+                Ok(None)
+            } else {
+                Ok(self
+                    .users
+                    .lock()
+                    .values()
+                    .find(|user| user.github_login == github_login)
+                    .cloned())
+            }
         }
 
         async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {

crates/collab/src/db_tests.rs 🔗

@@ -103,6 +103,64 @@ async fn test_get_users_by_ids() {
     }
 }
 
+#[tokio::test(flavor = "multi_thread")]
+async fn test_get_user_by_github_account() {
+    for test_db in [
+        TestDb::postgres().await,
+        TestDb::fake(build_background_executor()),
+    ] {
+        let db = test_db.db();
+        let user_id1 = db
+            .create_user(
+                "user1@example.com",
+                false,
+                NewUserParams {
+                    github_login: "login1".into(),
+                    github_user_id: 101,
+                    invite_count: 0,
+                },
+            )
+            .await
+            .unwrap();
+        let user_id2 = db
+            .create_user(
+                "user2@example.com",
+                false,
+                NewUserParams {
+                    github_login: "login2".into(),
+                    github_user_id: 102,
+                    invite_count: 0,
+                },
+            )
+            .await
+            .unwrap();
+
+        let user = db
+            .get_user_by_github_account("login1", None)
+            .await
+            .unwrap()
+            .unwrap();
+        assert_eq!(user.id, user_id1);
+        assert_eq!(&user.github_login, "login1");
+        assert_eq!(user.github_user_id, 101);
+
+        assert!(db
+            .get_user_by_github_account("non-existent-login", None)
+            .await
+            .unwrap()
+            .is_none());
+
+        let user = db
+            .get_user_by_github_account("the-new-login2", Some(102))
+            .await
+            .unwrap()
+            .unwrap();
+        assert_eq!(user.id, user_id2);
+        assert_eq!(&user.github_login, "the-new-login2");
+        assert_eq!(user.github_user_id, 102);
+    }
+}
+
 #[tokio::test(flavor = "multi_thread")]
 async fn test_worktree_extensions() {
     let test_db = TestDb::postgres().await;

crates/collab/src/integration_tests.rs 🔗

@@ -5173,17 +5173,25 @@ impl TestServer {
         });
 
         let http = FakeHttpClient::with_404_response();
-        let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
+        let user_id = if let Ok(Some(user)) = self
+            .app_state
+            .db
+            .get_user_by_github_account(name, None)
+            .await
         {
             user.id
         } else {
             self.app_state
                 .db
-                .create_user(&format!("{name}@example.com"), false, NewUserParams {
-                    github_login: name.into(),
-                    github_user_id: 0,
-                    invite_count: 0,
-                })
+                .create_user(
+                    &format!("{name}@example.com"),
+                    false,
+                    NewUserParams {
+                        github_login: name.into(),
+                        github_user_id: 0,
+                        invite_count: 0,
+                    },
+                )
                 .await
                 .unwrap()
         };

crates/collab/src/rpc.rs 🔗

@@ -1404,7 +1404,7 @@ impl Server {
         let users = match query.len() {
             0 => vec![],
             1 | 2 => db
-                .get_user_by_github_login(&query)
+                .get_user_by_github_account(&query, None)
                 .await?
                 .into_iter()
                 .collect(),