diff --git a/server/src/db.rs b/server/src/db.rs index 28aadd0d6334793f1307dfded99a550696571ba2..c3e270bc8759a92b971ce7a73b8914be8c9e7ace 100644 --- a/server/src/db.rs +++ b/server/src/db.rs @@ -128,29 +128,46 @@ impl Db { requester_id: UserId, ids: impl Iterator, ) -> Result> { + let mut include_requester = false; + let ids = ids + .map(|id| { + if id == requester_id { + include_requester = true; + } + id.0 + }) + .collect::>(); + test_support!(self, { // Only return users that are in a common channel with the requesting user. + // Also allow the requesting user to return their own data, even if they aren't + // in any channels. let query = " - SELECT users.* + SELECT + users.* FROM - users LEFT JOIN channel_memberships - ON - channel_memberships.user_id = users.id + users, channel_memberships WHERE - users.id = $2 OR - ( - users.id = ANY ($1) AND - channel_memberships.channel_id IN ( - SELECT channel_id - FROM channel_memberships - WHERE channel_memberships.user_id = $2 - ) + users.id = ANY ($1) AND + channel_memberships.user_id = users.id AND + channel_memberships.channel_id IN ( + SELECT channel_id + FROM channel_memberships + WHERE channel_memberships.user_id = $2 ) + UNION + SELECT + users.* + FROM + users + WHERE + $3 AND users.id = $2 "; sqlx::query_as(query) - .bind(&ids.map(|id| id.0).collect::>()) + .bind(&ids) .bind(requester_id) + .bind(include_requester) .fetch_all(&self.pool) .await }) @@ -571,16 +588,84 @@ pub mod tests { async fn test_get_users_by_ids() { let test_db = TestDb::new(); let db = test_db.db(); - let user_id = db.create_user("user", false).await.unwrap(); + + let user = db.create_user("user", false).await.unwrap(); + let friend1 = db.create_user("friend-1", false).await.unwrap(); + let friend2 = db.create_user("friend-2", false).await.unwrap(); + let friend3 = db.create_user("friend-3", false).await.unwrap(); + let stranger = db.create_user("stranger", false).await.unwrap(); + + // A user can read their own info, even if they aren't in any channels. assert_eq!( - db.get_users_by_ids(user_id, Some(user_id).iter().copied()) + db.get_users_by_ids( + user, + [user, friend1, friend2, friend3, stranger].iter().copied() + ) + .await + .unwrap(), + vec![User { + id: user, + github_login: "user".to_string(), + admin: false, + },], + ); + + // A user can read the info of any other user who is in a shared channel + // with them. + let org = db.create_org("test org", "test-org").await.unwrap(); + let chan1 = db.create_org_channel(org, "channel-1").await.unwrap(); + let chan2 = db.create_org_channel(org, "channel-2").await.unwrap(); + let chan3 = db.create_org_channel(org, "channel-3").await.unwrap(); + + db.add_channel_member(chan1, user, false).await.unwrap(); + db.add_channel_member(chan2, user, false).await.unwrap(); + db.add_channel_member(chan1, friend1, false).await.unwrap(); + db.add_channel_member(chan1, friend2, false).await.unwrap(); + db.add_channel_member(chan2, friend2, false).await.unwrap(); + db.add_channel_member(chan2, friend3, false).await.unwrap(); + db.add_channel_member(chan3, stranger, false).await.unwrap(); + + assert_eq!( + db.get_users_by_ids( + user, + [user, friend1, friend2, friend3, stranger].iter().copied() + ) + .await + .unwrap(), + vec![ + User { + id: user, + github_login: "user".to_string(), + admin: false, + }, + User { + id: friend1, + github_login: "friend-1".to_string(), + admin: false, + }, + User { + id: friend2, + github_login: "friend-2".to_string(), + admin: false, + }, + User { + id: friend3, + github_login: "friend-3".to_string(), + admin: false, + } + ] + ); + + // The user's own info is only returned if they request it. + assert_eq!( + db.get_users_by_ids(user, [friend1].iter().copied()) .await .unwrap(), vec![User { - id: user_id, - github_login: "user".to_string(), + id: friend1, + github_login: "friend-1".to_string(), admin: false, - }] + },] ) }