From d367d3fbbcc973eec2416fe1736899445853cccc Mon Sep 17 00:00:00 2001 From: Finn Evers Date: Tue, 14 Apr 2026 23:25:49 +0200 Subject: [PATCH] compliance: Make trait more flexible (#53914) This will make this easier to use with the GitHub worker. Release Notes: - N/A --- tooling/compliance/src/checks.rs | 59 ++++++---- tooling/compliance/src/github.rs | 153 ++++++++++++++++++++------ tooling/xtask/src/tasks/compliance.rs | 13 ++- 3 files changed, 165 insertions(+), 60 deletions(-) diff --git a/tooling/compliance/src/checks.rs b/tooling/compliance/src/checks.rs index 0ee8eed8081eaa5a4038cb89e933c215909d5029..9a3e9a6b208ea872efcad7db1abf8a16e80f8334 100644 --- a/tooling/compliance/src/checks.rs +++ b/tooling/compliance/src/checks.rs @@ -6,7 +6,7 @@ use crate::{ git::{CommitDetails, CommitList}, github::{ CommitAuthor, GitHubClient, GitHubUser, GithubLogin, PullRequestComment, PullRequestData, - PullRequestReview, ReviewState, + PullRequestReview, Repository, ReviewState, }, report::Report, }; @@ -118,7 +118,10 @@ impl<'a> Reporter<'a> { return Err(ReviewFailure::NoPullRequestFound); }; - let pull_request = self.github_client.get_pull_request(pr_number).await?; + let pull_request = self + .github_client + .get_pull_request(&Repository::ZED, pr_number) + .await?; if let Some(approval) = self .check_approving_pull_request_review(&pull_request) @@ -152,7 +155,7 @@ impl<'a> Reporter<'a> { if commit.co_authors().is_some() && let Some(commit_authors) = self .github_client - .get_commit_authors(&[commit.sha()]) + .get_commit_authors(&Repository::ZED, &[commit.sha()]) .await? .get(commit.sha()) .and_then(|authors| authors.co_authors()) @@ -162,7 +165,7 @@ impl<'a> Reporter<'a> { if let Some(github_login) = co_author.user() && self .github_client - .actor_has_repository_write_permission(github_login) + .check_repo_write_permission(&Repository::ZED, github_login) .await? { org_co_authors.push(co_author.clone()); @@ -186,7 +189,7 @@ impl<'a> Reporter<'a> { if let Some(user) = pull_request.user && self .github_client - .actor_has_repository_write_permission(&GithubLogin::new(user.login)) + .check_repo_write_permission(&Repository::ZED, &GithubLogin::new(user.login)) .await? .not() { @@ -209,7 +212,7 @@ impl<'a> Reporter<'a> { ) -> Result, ReviewFailure> { let pr_reviews = self .github_client - .get_pull_request_reviews(pull_request.number) + .get_pull_request_reviews(&Repository::ZED, pull_request.number) .await?; if !pr_reviews.is_empty() { @@ -229,9 +232,10 @@ impl<'a> Reporter<'a> { .is_some_and(Self::contains_approving_pattern)) && self .github_client - .actor_has_repository_write_permission(&GithubLogin::new( - github_login.login.clone(), - )) + .check_repo_write_permission( + &Repository::ZED, + &GithubLogin::new(github_login.login.clone()), + ) .await? { org_approving_reviews.push(review); @@ -253,7 +257,7 @@ impl<'a> Reporter<'a> { ) -> Result, ReviewFailure> { let other_comments = self .github_client - .get_pull_request_comments(pull_request.number) + .get_pull_request_comments(&Repository::ZED, pull_request.number) .await?; if !other_comments.is_empty() { @@ -270,9 +274,10 @@ impl<'a> Reporter<'a> { .is_some_and(Self::contains_approving_pattern) && self .github_client - .actor_has_repository_write_permission(&GithubLogin::new( - comment.user.login.clone(), - )) + .check_repo_write_permission( + &Repository::ZED, + &GithubLogin::new(comment.user.login.clone()), + ) .await? { org_approving_comments.push(comment); @@ -327,7 +332,7 @@ mod tests { use crate::git::{CommitDetails, CommitList, CommitSha}; use crate::github::{ AuthorsForCommits, GitHubApiClient, GitHubClient, GitHubUser, GithubLogin, - PullRequestComment, PullRequestData, PullRequestReview, ReviewState, + PullRequestComment, PullRequestData, PullRequestReview, Repository, ReviewState, }; use super::{Reporter, ReviewFailure, ReviewSuccess}; @@ -342,12 +347,17 @@ mod tests { #[async_trait::async_trait(?Send)] impl GitHubApiClient for MockGitHubApi { - async fn get_pull_request(&self, _pr_number: u64) -> anyhow::Result { + async fn get_pull_request( + &self, + _repo: &Repository<'_>, + _pr_number: u64, + ) -> anyhow::Result { Ok(self.pull_request.clone()) } async fn get_pull_request_reviews( &self, + _repo: &Repository<'_>, _pr_number: u64, ) -> anyhow::Result> { Ok(self.reviews.clone()) @@ -355,6 +365,7 @@ mod tests { async fn get_pull_request_comments( &self, + _repo: &Repository<'_>, _pr_number: u64, ) -> anyhow::Result> { Ok(self.comments.clone()) @@ -362,23 +373,29 @@ mod tests { async fn get_commit_authors( &self, + _repo: &Repository<'_>, _commit_shas: &[&CommitSha], ) -> anyhow::Result { serde_json::from_value(self.commit_authors_json.clone()).map_err(Into::into) } - async fn check_org_membership(&self, login: &GithubLogin) -> anyhow::Result { + async fn check_repo_write_permission( + &self, + _repo: &Repository<'_>, + login: &GithubLogin, + ) -> anyhow::Result { Ok(self .org_members .iter() .any(|member| member == login.as_str())) } - async fn check_repo_write_permission(&self, _login: &GithubLogin) -> anyhow::Result { - Ok(false) - } - - async fn add_label_to_issue(&self, _label: &str, _pr_number: u64) -> anyhow::Result<()> { + async fn add_label_to_issue( + &self, + _repo: &Repository<'_>, + _label: &str, + _pr_number: u64, + ) -> anyhow::Result<()> { Ok(()) } } diff --git a/tooling/compliance/src/github.rs b/tooling/compliance/src/github.rs index bb1b55358a1f67fc68809994a6a21e6b56e0a16d..69479b83806e4aa8c14532c1a27c453e72811e3f 100644 --- a/tooling/compliance/src/github.rs +++ b/tooling/compliance/src/github.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, fmt, ops::Not, rc::Rc}; +use std::{borrow::Cow, collections::HashMap, fmt, ops::Not, rc::Rc}; use anyhow::Result; use derive_more::Deref; @@ -141,22 +141,73 @@ impl<'de> serde::Deserialize<'de> for AuthorsForCommits { } } +#[derive(Clone)] +pub struct Repository<'a> { + owner: Cow<'a, str>, + name: Cow<'a, str>, +} + +impl<'a> Repository<'a> { + pub const ZED: Repository<'static> = Repository::new_static("zed-industries", "zed"); + + pub fn new(owner: &'a str, name: &'a str) -> Self { + Self { + owner: Cow::Borrowed(owner), + name: Cow::Borrowed(name), + } + } + + pub fn owner(&self) -> &str { + &self.owner + } + + pub fn name(&self) -> &str { + &self.name + } +} + +impl Repository<'static> { + pub const fn new_static(owner: &'static str, name: &'static str) -> Self { + Self { + owner: Cow::Borrowed(owner), + name: Cow::Borrowed(name), + } + } +} + #[async_trait::async_trait(?Send)] pub trait GitHubApiClient { - async fn get_pull_request(&self, pr_number: u64) -> Result; - async fn get_pull_request_reviews(&self, pr_number: u64) -> Result>; - async fn get_pull_request_comments(&self, pr_number: u64) -> Result>; - async fn get_commit_authors(&self, commit_shas: &[&CommitSha]) -> Result; - async fn check_org_membership(&self, login: &GithubLogin) -> Result; - async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result; - async fn actor_has_repository_write_permission( + async fn get_pull_request( + &self, + repo: &Repository<'_>, + pr_number: u64, + ) -> Result; + async fn get_pull_request_reviews( &self, + repo: &Repository<'_>, + pr_number: u64, + ) -> Result>; + async fn get_pull_request_comments( + &self, + repo: &Repository<'_>, + pr_number: u64, + ) -> Result>; + async fn get_commit_authors( + &self, + repo: &Repository<'_>, + commit_shas: &[&CommitSha], + ) -> Result; + async fn check_repo_write_permission( + &self, + repo: &Repository<'_>, login: &GithubLogin, - ) -> anyhow::Result { - Ok(self.check_org_membership(login).await? - || self.check_repo_write_permission(login).await?) - } - async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()>; + ) -> Result; + async fn add_label_to_issue( + &self, + repo: &Repository<'_>, + label: &str, + issue_number: u64, + ) -> Result<()>; } #[derive(Deref)] @@ -170,8 +221,8 @@ impl GitHubClient { } #[cfg(feature = "octo-client")] - pub async fn for_app(app_id: u64, app_private_key: &str) -> Result { - let client = OctocrabClient::new(app_id, app_private_key).await?; + pub async fn for_app_in_repo(app_id: u64, app_private_key: &str, org: &str) -> Result { + let client = OctocrabClient::new(app_id, app_private_key, org).await?; Ok(Self::new(Rc::new(client))) } } @@ -276,7 +327,10 @@ mod octo_client { use serde::de::DeserializeOwned; use tokio::pin; - use crate::{git::CommitSha, github::graph_ql}; + use crate::{ + git::CommitSha, + github::{Repository, graph_ql}, + }; use super::{ AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment, @@ -284,15 +338,13 @@ mod octo_client { }; const PAGE_SIZE: u8 = 100; - const ORG: &str = "zed-industries"; - const REPO: &str = "zed"; pub struct OctocrabClient { client: Octocrab, } impl OctocrabClient { - pub async fn new(app_id: u64, app_private_key: &str) -> Result { + pub async fn new(app_id: u64, app_private_key: &str, org: &str) -> Result { let octocrab = Octocrab::builder() .cache(InMemoryCache::new()) .app( @@ -311,7 +363,7 @@ mod octo_client { let installation_id = installations .into_iter() - .find(|installation| installation.account.login == ORG) + .find(|installation| installation.account.login == org) .context("Could not find Zed repository in installations")? .id; @@ -355,8 +407,16 @@ mod octo_client { #[async_trait::async_trait(?Send)] impl GitHubApiClient for OctocrabClient { - async fn get_pull_request(&self, pr_number: u64) -> Result { - let pr = self.client.pulls(ORG, REPO).get(pr_number).await?; + async fn get_pull_request( + &self, + repo: &Repository<'_>, + pr_number: u64, + ) -> Result { + let pr = self + .client + .pulls(repo.owner.as_ref(), repo.name.as_ref()) + .get(pr_number) + .await?; Ok(PullRequestData { number: pr.number, user: pr.user.map(|user| GitHubUser { login: user.login }), @@ -367,10 +427,14 @@ mod octo_client { }) } - async fn get_pull_request_reviews(&self, pr_number: u64) -> Result> { + async fn get_pull_request_reviews( + &self, + repo: &Repository<'_>, + pr_number: u64, + ) -> Result> { let page = self .client - .pulls(ORG, REPO) + .pulls(repo.owner.as_ref(), repo.name.as_ref()) .list_reviews(pr_number) .per_page(PAGE_SIZE) .send() @@ -393,11 +457,12 @@ mod octo_client { async fn get_pull_request_comments( &self, + repo: &Repository<'_>, pr_number: u64, ) -> Result> { let page = self .client - .issues(ORG, REPO) + .issues(repo.owner.as_ref(), repo.name.as_ref()) .list_comments(pr_number) .per_page(PAGE_SIZE) .send() @@ -418,19 +483,29 @@ mod octo_client { async fn get_commit_authors( &self, + repo: &Repository<'_>, commit_shas: &[&CommitSha], ) -> Result { - let query = graph_ql::build_co_authors_query(ORG, REPO, commit_shas.iter().copied()); + let query = graph_ql::build_co_authors_query( + repo.owner.as_ref(), + repo.name.as_ref(), + commit_shas.iter().copied(), + ); let query = serde_json::json!({ "query": query }); self.graphql::(&query) .await .map(|response| response.repository) } - async fn check_org_membership(&self, login: &GithubLogin) -> Result { + async fn check_repo_write_permission( + &self, + repo: &Repository<'_>, + login: &GithubLogin, + ) -> Result { + // Check org membership first - we save ourselves a few request that way let page = self .client - .orgs(ORG) + .orgs(repo.owner.as_ref()) .list_members() .per_page(PAGE_SIZE) .send() @@ -438,12 +513,13 @@ mod octo_client { let members = self.get_all(page).await?; - Ok(members + if members .into_iter() - .any(|member| member.login == login.as_str())) - } + .any(|member| member.login == login.as_str()) + { + return Ok(true); + } - async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result { // TODO: octocrab fails to deserialize the permission response and // does not adhere to the scheme laid out at // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user @@ -466,7 +542,9 @@ mod octo_client { self.client .get::( format!( - "/repos/{ORG}/{REPO}/collaborators/{user}/permission", + "/repos/{owner}/{repo}/collaborators/{user}/permission", + owner = repo.owner.as_ref(), + repo = repo.name.as_ref(), user = login.as_str() ), None::<&()>, @@ -481,9 +559,14 @@ mod octo_client { .map_err(Into::into) } - async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()> { + async fn add_label_to_issue( + &self, + repo: &Repository<'_>, + label: &str, + issue_number: u64, + ) -> Result<()> { self.client - .issues(ORG, REPO) + .issues(repo.owner.as_ref(), repo.name.as_ref()) .add_labels(issue_number, &[label.to_owned()]) .await .map(|_| ()) diff --git a/tooling/xtask/src/tasks/compliance.rs b/tooling/xtask/src/tasks/compliance.rs index 6d055665c6a5cc804841cc90c9af757517cbd14f..d5b98409be2bcb142158fc75de94eec2e3a1f4e5 100644 --- a/tooling/xtask/src/tasks/compliance.rs +++ b/tooling/xtask/src/tasks/compliance.rs @@ -6,7 +6,7 @@ use clap::Parser; use compliance::{ checks::Reporter, git::{CommitsFromVersionToVersion, GetVersionTags, GitCommand, VersionTag}, - github::GitHubClient, + github::{GitHubClient, Repository}, report::ReportReviewSummary, }; @@ -69,9 +69,10 @@ async fn check_compliance_impl(args: ComplianceArgs) -> Result<()> { println!("Checking commit range {range}, {} total", commits.len()); - let client = GitHubClient::for_app( + let client = GitHubClient::for_app_in_repo( app_id.parse().context("Failed to parse app ID as int")?, key.as_ref(), + Repository::ZED.owner(), ) .await?; @@ -93,7 +94,7 @@ async fn check_compliance_impl(args: ComplianceArgs) -> Result<()> { for report in report.errors() { if let Some(pr_number) = report.commit.pr_number() - && let Ok(pull_request) = client.get_pull_request(pr_number).await + && let Ok(pull_request) = client.get_pull_request(&Repository::ZED, pr_number).await && pull_request.labels.is_none_or(|labels| { labels .iter() @@ -103,7 +104,11 @@ async fn check_compliance_impl(args: ComplianceArgs) -> Result<()> { println!("Adding review label to PR {}...", pr_number); client - .add_label_to_issue(compliance::github::PR_REVIEW_LABEL, pr_number) + .add_label_to_issue( + &Repository::ZED, + compliance::github::PR_REVIEW_LABEL, + pr_number, + ) .await?; } }