diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index de8ec44c782eaa0309f8b0eb2dacc759efbe470e..73293e0b2c1f83e805560599b939f50fbc88b8a4 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -25,10 +25,7 @@ use tracing::instrument; pub fn routes(rpc_server: &Arc, state: Arc) -> Router { Router::new() .route("/users", get(get_users).post(create_user)) - .route( - "/users/:id", - put(update_user).delete(destroy_user).get(get_user), - ) + .route("/users/:id", put(update_user).delete(destroy_user)) .route("/users/:id/access_tokens", post(create_access_token)) .route("/users_with_no_invites", get(get_users_with_no_invites)) .route("/invite_codes/:code", get(get_user_for_invite_code)) @@ -90,6 +87,8 @@ pub async fn validate_api_token(req: Request, next: Next) -> impl IntoR #[derive(Debug, Deserialize)] struct GetUsersQueryParams { + github_user_id: Option, + github_login: Option, query: Option, page: Option, limit: Option, @@ -99,6 +98,14 @@ async fn get_users( Query(params): Query, Extension(app): Extension>, ) -> Result>> { + if let Some(github_login) = ¶ms.github_login { + let user = app + .db + .get_user_by_github_account(github_login, params.github_user_id) + .await?; + return Ok(Json(Vec::from_iter(user))); + } + let limit = params.limit.unwrap_or(100); let users = if let Some(query) = params.query { app.db.fuzzy_search_users(&query, limit).await? @@ -205,18 +212,6 @@ async fn destroy_user( Ok(()) } -async fn get_user( - Path(login): Path, - Extension(app): Extension>, -) -> Result> { - let user = app - .db - .get_user_by_github_login(&login) - .await? - .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "User not found".to_string()))?; - Ok(Json(user)) -} - #[derive(Debug, Deserialize)] struct GetUsersWithNoInvites { invited_by_another_user: bool, @@ -351,22 +346,24 @@ struct CreateAccessTokenResponse { } async fn create_access_token( - Path(login): Path, + Path(user_id): Path, Query(params): Query, Extension(app): Extension>, ) -> Result> { - // request.require_token().await?; - let user = app .db - .get_user_by_github_login(&login) + .get_user_by_id(user_id) .await? .ok_or_else(|| anyhow!("user not found"))?; let mut user_id = user.id; if let Some(impersonate) = params.impersonate { if user.admin { - if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? { + if let Some(impersonated_user) = app + .db + .get_user_by_github_account(&impersonate, None) + .await? + { user_id = impersonated_user.id; } else { return Err(Error::Http( diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index f31defa57751514cade745cf4530b954ec1d4398..70dc0c4e5b565dad7e3faa7bc8c45adc4ccd8366 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -23,7 +23,11 @@ pub trait Db: Send + Sync { async fn get_user_by_id(&self, id: UserId) -> Result>; async fn get_users_by_ids(&self, ids: Vec) -> Result>; async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result>; - async fn get_user_by_github_login(&self, github_login: &str) -> Result>; + async fn get_user_by_github_account( + &self, + github_login: &str, + github_user_id: Option, + ) -> Result>; async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>; async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>; async fn destroy_user(&self, id: UserId) -> Result<()>; @@ -274,12 +278,53 @@ impl Db for PostgresDb { Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?) } - async fn get_user_by_github_login(&self, github_login: &str) -> Result> { - let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1"; - Ok(sqlx::query_as(query) + async fn get_user_by_github_account( + &self, + github_login: &str, + github_user_id: Option, + ) -> Result> { + if let Some(github_user_id) = github_user_id { + let mut user = sqlx::query_as::<_, User>( + " + UPDATE users + SET github_login = $1 + WHERE github_user_id = $2 + RETURNING * + ", + ) + .bind(github_login) + .bind(github_user_id) + .fetch_optional(&self.pool) + .await?; + + if user.is_none() { + user = sqlx::query_as::<_, User>( + " + UPDATE users + SET github_user_id = $1 + WHERE github_login = $2 + RETURNING * + ", + ) + .bind(github_user_id) + .bind(github_login) + .fetch_optional(&self.pool) + .await?; + } + + Ok(user) + } else { + Ok(sqlx::query_as( + " + SELECT * FROM users + WHERE github_login = $1 + LIMIT 1 + ", + ) .bind(github_login) .fetch_optional(&self.pool) .await?) + } } async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { @@ -1777,14 +1822,32 @@ mod test { unimplemented!() } - async fn get_user_by_github_login(&self, github_login: &str) -> Result> { + async fn get_user_by_github_account( + &self, + github_login: &str, + github_user_id: Option, + ) -> Result> { self.background.simulate_random_delay().await; - Ok(self - .users - .lock() - .values() - .find(|user| user.github_login == github_login) - .cloned()) + if let Some(github_user_id) = github_user_id { + for user in self.users.lock().values_mut() { + if user.github_user_id == github_user_id { + user.github_login = github_login.into(); + return Ok(Some(user.clone())); + } + if user.github_login == github_login { + user.github_user_id = github_user_id; + return Ok(Some(user.clone())); + } + } + Ok(None) + } else { + Ok(self + .users + .lock() + .values() + .find(|user| user.github_login == github_login) + .cloned()) + } } async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> { diff --git a/crates/collab/src/db_tests.rs b/crates/collab/src/db_tests.rs index 87033fab38f30128493dec2329c0b51f580a51b5..49ac053fd80124bbe4ad0878d477472d3cbdcfab 100644 --- a/crates/collab/src/db_tests.rs +++ b/crates/collab/src/db_tests.rs @@ -103,6 +103,64 @@ async fn test_get_users_by_ids() { } } +#[tokio::test(flavor = "multi_thread")] +async fn test_get_user_by_github_account() { + for test_db in [ + TestDb::postgres().await, + TestDb::fake(build_background_executor()), + ] { + let db = test_db.db(); + let user_id1 = db + .create_user( + "user1@example.com", + false, + NewUserParams { + github_login: "login1".into(), + github_user_id: 101, + invite_count: 0, + }, + ) + .await + .unwrap(); + let user_id2 = db + .create_user( + "user2@example.com", + false, + NewUserParams { + github_login: "login2".into(), + github_user_id: 102, + invite_count: 0, + }, + ) + .await + .unwrap(); + + let user = db + .get_user_by_github_account("login1", None) + .await + .unwrap() + .unwrap(); + assert_eq!(user.id, user_id1); + assert_eq!(&user.github_login, "login1"); + assert_eq!(user.github_user_id, 101); + + assert!(db + .get_user_by_github_account("non-existent-login", None) + .await + .unwrap() + .is_none()); + + let user = db + .get_user_by_github_account("the-new-login2", Some(102)) + .await + .unwrap() + .unwrap(); + assert_eq!(user.id, user_id2); + assert_eq!(&user.github_login, "the-new-login2"); + assert_eq!(user.github_user_id, 102); + } +} + #[tokio::test(flavor = "multi_thread")] async fn test_worktree_extensions() { let test_db = TestDb::postgres().await; diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 94811b095189e4cd5cd47c2d3aaacde21063d3e8..f3d43f277e457de0bde65ae2db7c51bba170934f 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -5173,17 +5173,25 @@ impl TestServer { }); let http = FakeHttpClient::with_404_response(); - let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await + let user_id = if let Ok(Some(user)) = self + .app_state + .db + .get_user_by_github_account(name, None) + .await { user.id } else { self.app_state .db - .create_user(&format!("{name}@example.com"), false, NewUserParams { - github_login: name.into(), - github_user_id: 0, - invite_count: 0, - }) + .create_user( + &format!("{name}@example.com"), + false, + NewUserParams { + github_login: name.into(), + github_user_id: 0, + invite_count: 0, + }, + ) .await .unwrap() }; diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 4fc022995f2344173c90c18e67b45b3683d869e1..5f27352c5a08a83636e2d8f2fcfbd5c77099a226 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1404,7 +1404,7 @@ impl Server { let users = match query.len() { 0 => vec![], 1 | 2 => db - .get_user_by_github_login(&query) + .get_user_by_github_account(&query, None) .await? .into_iter() .collect(),