Fix duplicated results in get_users_by_ids

Max Brunsfeld created

Change summary

server/src/db.rs | 121 ++++++++++++++++++++++++++++++++++++++++++-------
1 file changed, 103 insertions(+), 18 deletions(-)

Detailed changes

server/src/db.rs 🔗

@@ -128,29 +128,46 @@ impl Db {
         requester_id: UserId,
         ids: impl Iterator<Item = UserId>,
     ) -> Result<Vec<User>> {
+        let mut include_requester = false;
+        let ids = ids
+            .map(|id| {
+                if id == requester_id {
+                    include_requester = true;
+                }
+                id.0
+            })
+            .collect::<Vec<_>>();
+
         test_support!(self, {
             // Only return users that are in a common channel with the requesting user.
+            // Also allow the requesting user to return their own data, even if they aren't
+            // in any channels.
             let query = "
-                SELECT users.*
+                SELECT
+                    users.*
                 FROM
-                    users LEFT JOIN channel_memberships
-                ON
-                    channel_memberships.user_id = users.id
+                    users, channel_memberships
                 WHERE
-                    users.id = $2 OR
-                    (
-                        users.id = ANY ($1) AND
-                        channel_memberships.channel_id IN (
-                            SELECT channel_id
-                            FROM channel_memberships
-                            WHERE channel_memberships.user_id = $2
-                        )
+                    users.id = ANY ($1) AND
+                    channel_memberships.user_id = users.id AND
+                    channel_memberships.channel_id IN (
+                        SELECT channel_id
+                        FROM channel_memberships
+                        WHERE channel_memberships.user_id = $2
                     )
+                UNION
+                SELECT
+                    users.*
+                FROM
+                    users
+                WHERE
+                    $3 AND users.id = $2
             ";
 
             sqlx::query_as(query)
-                .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
+                .bind(&ids)
                 .bind(requester_id)
+                .bind(include_requester)
                 .fetch_all(&self.pool)
                 .await
         })
@@ -571,16 +588,84 @@ pub mod tests {
     async fn test_get_users_by_ids() {
         let test_db = TestDb::new();
         let db = test_db.db();
-        let user_id = db.create_user("user", false).await.unwrap();
+
+        let user = db.create_user("user", false).await.unwrap();
+        let friend1 = db.create_user("friend-1", false).await.unwrap();
+        let friend2 = db.create_user("friend-2", false).await.unwrap();
+        let friend3 = db.create_user("friend-3", false).await.unwrap();
+        let stranger = db.create_user("stranger", false).await.unwrap();
+
+        // A user can read their own info, even if they aren't in any channels.
         assert_eq!(
-            db.get_users_by_ids(user_id, Some(user_id).iter().copied())
+            db.get_users_by_ids(
+                user,
+                [user, friend1, friend2, friend3, stranger].iter().copied()
+            )
+            .await
+            .unwrap(),
+            vec![User {
+                id: user,
+                github_login: "user".to_string(),
+                admin: false,
+            },],
+        );
+
+        // A user can read the info of any other user who is in a shared channel
+        // with them.
+        let org = db.create_org("test org", "test-org").await.unwrap();
+        let chan1 = db.create_org_channel(org, "channel-1").await.unwrap();
+        let chan2 = db.create_org_channel(org, "channel-2").await.unwrap();
+        let chan3 = db.create_org_channel(org, "channel-3").await.unwrap();
+
+        db.add_channel_member(chan1, user, false).await.unwrap();
+        db.add_channel_member(chan2, user, false).await.unwrap();
+        db.add_channel_member(chan1, friend1, false).await.unwrap();
+        db.add_channel_member(chan1, friend2, false).await.unwrap();
+        db.add_channel_member(chan2, friend2, false).await.unwrap();
+        db.add_channel_member(chan2, friend3, false).await.unwrap();
+        db.add_channel_member(chan3, stranger, false).await.unwrap();
+
+        assert_eq!(
+            db.get_users_by_ids(
+                user,
+                [user, friend1, friend2, friend3, stranger].iter().copied()
+            )
+            .await
+            .unwrap(),
+            vec![
+                User {
+                    id: user,
+                    github_login: "user".to_string(),
+                    admin: false,
+                },
+                User {
+                    id: friend1,
+                    github_login: "friend-1".to_string(),
+                    admin: false,
+                },
+                User {
+                    id: friend2,
+                    github_login: "friend-2".to_string(),
+                    admin: false,
+                },
+                User {
+                    id: friend3,
+                    github_login: "friend-3".to_string(),
+                    admin: false,
+                }
+            ]
+        );
+
+        // The user's own info is only returned if they request it.
+        assert_eq!(
+            db.get_users_by_ids(user, [friend1].iter().copied())
                 .await
                 .unwrap(),
             vec![User {
-                id: user_id,
-                github_login: "user".to_string(),
+                id: friend1,
+                github_login: "friend-1".to_string(),
                 admin: false,
-            }]
+            },]
         )
     }