compliance: Make trait more flexible (#53914)

Finn Evers created

This will make this easier to use with the GitHub worker.

Release Notes:

- N/A

Change summary

tooling/compliance/src/checks.rs      |  59 +++++++---
tooling/compliance/src/github.rs      | 153 ++++++++++++++++++++++------
tooling/xtask/src/tasks/compliance.rs |  13 +
3 files changed, 165 insertions(+), 60 deletions(-)

Detailed changes

tooling/compliance/src/checks.rs 🔗

@@ -6,7 +6,7 @@ use crate::{
     git::{CommitDetails, CommitList},
     github::{
         CommitAuthor, GitHubClient, GitHubUser, GithubLogin, PullRequestComment, PullRequestData,
-        PullRequestReview, ReviewState,
+        PullRequestReview, Repository, ReviewState,
     },
     report::Report,
 };
@@ -118,7 +118,10 @@ impl<'a> Reporter<'a> {
             return Err(ReviewFailure::NoPullRequestFound);
         };
 
-        let pull_request = self.github_client.get_pull_request(pr_number).await?;
+        let pull_request = self
+            .github_client
+            .get_pull_request(&Repository::ZED, pr_number)
+            .await?;
 
         if let Some(approval) = self
             .check_approving_pull_request_review(&pull_request)
@@ -152,7 +155,7 @@ impl<'a> Reporter<'a> {
         if commit.co_authors().is_some()
             && let Some(commit_authors) = self
                 .github_client
-                .get_commit_authors(&[commit.sha()])
+                .get_commit_authors(&Repository::ZED, &[commit.sha()])
                 .await?
                 .get(commit.sha())
                 .and_then(|authors| authors.co_authors())
@@ -162,7 +165,7 @@ impl<'a> Reporter<'a> {
                 if let Some(github_login) = co_author.user()
                     && self
                         .github_client
-                        .actor_has_repository_write_permission(github_login)
+                        .check_repo_write_permission(&Repository::ZED, github_login)
                         .await?
                 {
                     org_co_authors.push(co_author.clone());
@@ -186,7 +189,7 @@ impl<'a> Reporter<'a> {
         if let Some(user) = pull_request.user
             && self
                 .github_client
-                .actor_has_repository_write_permission(&GithubLogin::new(user.login))
+                .check_repo_write_permission(&Repository::ZED, &GithubLogin::new(user.login))
                 .await?
                 .not()
         {
@@ -209,7 +212,7 @@ impl<'a> Reporter<'a> {
     ) -> Result<Option<ReviewSuccess>, ReviewFailure> {
         let pr_reviews = self
             .github_client
-            .get_pull_request_reviews(pull_request.number)
+            .get_pull_request_reviews(&Repository::ZED, pull_request.number)
             .await?;
 
         if !pr_reviews.is_empty() {
@@ -229,9 +232,10 @@ impl<'a> Reporter<'a> {
                             .is_some_and(Self::contains_approving_pattern))
                     && self
                         .github_client
-                        .actor_has_repository_write_permission(&GithubLogin::new(
-                            github_login.login.clone(),
-                        ))
+                        .check_repo_write_permission(
+                            &Repository::ZED,
+                            &GithubLogin::new(github_login.login.clone()),
+                        )
                         .await?
                 {
                     org_approving_reviews.push(review);
@@ -253,7 +257,7 @@ impl<'a> Reporter<'a> {
     ) -> Result<Option<ReviewSuccess>, ReviewFailure> {
         let other_comments = self
             .github_client
-            .get_pull_request_comments(pull_request.number)
+            .get_pull_request_comments(&Repository::ZED, pull_request.number)
             .await?;
 
         if !other_comments.is_empty() {
@@ -270,9 +274,10 @@ impl<'a> Reporter<'a> {
                         .is_some_and(Self::contains_approving_pattern)
                     && self
                         .github_client
-                        .actor_has_repository_write_permission(&GithubLogin::new(
-                            comment.user.login.clone(),
-                        ))
+                        .check_repo_write_permission(
+                            &Repository::ZED,
+                            &GithubLogin::new(comment.user.login.clone()),
+                        )
                         .await?
                 {
                     org_approving_comments.push(comment);
@@ -327,7 +332,7 @@ mod tests {
     use crate::git::{CommitDetails, CommitList, CommitSha};
     use crate::github::{
         AuthorsForCommits, GitHubApiClient, GitHubClient, GitHubUser, GithubLogin,
-        PullRequestComment, PullRequestData, PullRequestReview, ReviewState,
+        PullRequestComment, PullRequestData, PullRequestReview, Repository, ReviewState,
     };
 
     use super::{Reporter, ReviewFailure, ReviewSuccess};
@@ -342,12 +347,17 @@ mod tests {
 
     #[async_trait::async_trait(?Send)]
     impl GitHubApiClient for MockGitHubApi {
-        async fn get_pull_request(&self, _pr_number: u64) -> anyhow::Result<PullRequestData> {
+        async fn get_pull_request(
+            &self,
+            _repo: &Repository<'_>,
+            _pr_number: u64,
+        ) -> anyhow::Result<PullRequestData> {
             Ok(self.pull_request.clone())
         }
 
         async fn get_pull_request_reviews(
             &self,
+            _repo: &Repository<'_>,
             _pr_number: u64,
         ) -> anyhow::Result<Vec<PullRequestReview>> {
             Ok(self.reviews.clone())
@@ -355,6 +365,7 @@ mod tests {
 
         async fn get_pull_request_comments(
             &self,
+            _repo: &Repository<'_>,
             _pr_number: u64,
         ) -> anyhow::Result<Vec<PullRequestComment>> {
             Ok(self.comments.clone())
@@ -362,23 +373,29 @@ mod tests {
 
         async fn get_commit_authors(
             &self,
+            _repo: &Repository<'_>,
             _commit_shas: &[&CommitSha],
         ) -> anyhow::Result<AuthorsForCommits> {
             serde_json::from_value(self.commit_authors_json.clone()).map_err(Into::into)
         }
 
-        async fn check_org_membership(&self, login: &GithubLogin) -> anyhow::Result<bool> {
+        async fn check_repo_write_permission(
+            &self,
+            _repo: &Repository<'_>,
+            login: &GithubLogin,
+        ) -> anyhow::Result<bool> {
             Ok(self
                 .org_members
                 .iter()
                 .any(|member| member == login.as_str()))
         }
 
-        async fn check_repo_write_permission(&self, _login: &GithubLogin) -> anyhow::Result<bool> {
-            Ok(false)
-        }
-
-        async fn add_label_to_issue(&self, _label: &str, _pr_number: u64) -> anyhow::Result<()> {
+        async fn add_label_to_issue(
+            &self,
+            _repo: &Repository<'_>,
+            _label: &str,
+            _pr_number: u64,
+        ) -> anyhow::Result<()> {
             Ok(())
         }
     }

tooling/compliance/src/github.rs 🔗

@@ -1,4 +1,4 @@
-use std::{collections::HashMap, fmt, ops::Not, rc::Rc};
+use std::{borrow::Cow, collections::HashMap, fmt, ops::Not, rc::Rc};
 
 use anyhow::Result;
 use derive_more::Deref;
@@ -141,22 +141,73 @@ impl<'de> serde::Deserialize<'de> for AuthorsForCommits {
     }
 }
 
+#[derive(Clone)]
+pub struct Repository<'a> {
+    owner: Cow<'a, str>,
+    name: Cow<'a, str>,
+}
+
+impl<'a> Repository<'a> {
+    pub const ZED: Repository<'static> = Repository::new_static("zed-industries", "zed");
+
+    pub fn new(owner: &'a str, name: &'a str) -> Self {
+        Self {
+            owner: Cow::Borrowed(owner),
+            name: Cow::Borrowed(name),
+        }
+    }
+
+    pub fn owner(&self) -> &str {
+        &self.owner
+    }
+
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+impl Repository<'static> {
+    pub const fn new_static(owner: &'static str, name: &'static str) -> Self {
+        Self {
+            owner: Cow::Borrowed(owner),
+            name: Cow::Borrowed(name),
+        }
+    }
+}
+
 #[async_trait::async_trait(?Send)]
 pub trait GitHubApiClient {
-    async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData>;
-    async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>>;
-    async fn get_pull_request_comments(&self, pr_number: u64) -> Result<Vec<PullRequestComment>>;
-    async fn get_commit_authors(&self, commit_shas: &[&CommitSha]) -> Result<AuthorsForCommits>;
-    async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool>;
-    async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool>;
-    async fn actor_has_repository_write_permission(
+    async fn get_pull_request(
+        &self,
+        repo: &Repository<'_>,
+        pr_number: u64,
+    ) -> Result<PullRequestData>;
+    async fn get_pull_request_reviews(
         &self,
+        repo: &Repository<'_>,
+        pr_number: u64,
+    ) -> Result<Vec<PullRequestReview>>;
+    async fn get_pull_request_comments(
+        &self,
+        repo: &Repository<'_>,
+        pr_number: u64,
+    ) -> Result<Vec<PullRequestComment>>;
+    async fn get_commit_authors(
+        &self,
+        repo: &Repository<'_>,
+        commit_shas: &[&CommitSha],
+    ) -> Result<AuthorsForCommits>;
+    async fn check_repo_write_permission(
+        &self,
+        repo: &Repository<'_>,
         login: &GithubLogin,
-    ) -> anyhow::Result<bool> {
-        Ok(self.check_org_membership(login).await?
-            || self.check_repo_write_permission(login).await?)
-    }
-    async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()>;
+    ) -> Result<bool>;
+    async fn add_label_to_issue(
+        &self,
+        repo: &Repository<'_>,
+        label: &str,
+        issue_number: u64,
+    ) -> Result<()>;
 }
 
 #[derive(Deref)]
@@ -170,8 +221,8 @@ impl GitHubClient {
     }
 
     #[cfg(feature = "octo-client")]
-    pub async fn for_app(app_id: u64, app_private_key: &str) -> Result<Self> {
-        let client = OctocrabClient::new(app_id, app_private_key).await?;
+    pub async fn for_app_in_repo(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
+        let client = OctocrabClient::new(app_id, app_private_key, org).await?;
         Ok(Self::new(Rc::new(client)))
     }
 }
@@ -276,7 +327,10 @@ mod octo_client {
     use serde::de::DeserializeOwned;
     use tokio::pin;
 
-    use crate::{git::CommitSha, github::graph_ql};
+    use crate::{
+        git::CommitSha,
+        github::{Repository, graph_ql},
+    };
 
     use super::{
         AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment,
@@ -284,15 +338,13 @@ mod octo_client {
     };
 
     const PAGE_SIZE: u8 = 100;
-    const ORG: &str = "zed-industries";
-    const REPO: &str = "zed";
 
     pub struct OctocrabClient {
         client: Octocrab,
     }
 
     impl OctocrabClient {
-        pub async fn new(app_id: u64, app_private_key: &str) -> Result<Self> {
+        pub async fn new(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
             let octocrab = Octocrab::builder()
                 .cache(InMemoryCache::new())
                 .app(
@@ -311,7 +363,7 @@ mod octo_client {
 
             let installation_id = installations
                 .into_iter()
-                .find(|installation| installation.account.login == ORG)
+                .find(|installation| installation.account.login == org)
                 .context("Could not find Zed repository in installations")?
                 .id;
 
@@ -355,8 +407,16 @@ mod octo_client {
 
     #[async_trait::async_trait(?Send)]
     impl GitHubApiClient for OctocrabClient {
-        async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
-            let pr = self.client.pulls(ORG, REPO).get(pr_number).await?;
+        async fn get_pull_request(
+            &self,
+            repo: &Repository<'_>,
+            pr_number: u64,
+        ) -> Result<PullRequestData> {
+            let pr = self
+                .client
+                .pulls(repo.owner.as_ref(), repo.name.as_ref())
+                .get(pr_number)
+                .await?;
             Ok(PullRequestData {
                 number: pr.number,
                 user: pr.user.map(|user| GitHubUser { login: user.login }),
@@ -367,10 +427,14 @@ mod octo_client {
             })
         }
 
-        async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
+        async fn get_pull_request_reviews(
+            &self,
+            repo: &Repository<'_>,
+            pr_number: u64,
+        ) -> Result<Vec<PullRequestReview>> {
             let page = self
                 .client
-                .pulls(ORG, REPO)
+                .pulls(repo.owner.as_ref(), repo.name.as_ref())
                 .list_reviews(pr_number)
                 .per_page(PAGE_SIZE)
                 .send()
@@ -393,11 +457,12 @@ mod octo_client {
 
         async fn get_pull_request_comments(
             &self,
+            repo: &Repository<'_>,
             pr_number: u64,
         ) -> Result<Vec<PullRequestComment>> {
             let page = self
                 .client
-                .issues(ORG, REPO)
+                .issues(repo.owner.as_ref(), repo.name.as_ref())
                 .list_comments(pr_number)
                 .per_page(PAGE_SIZE)
                 .send()
@@ -418,19 +483,29 @@ mod octo_client {
 
         async fn get_commit_authors(
             &self,
+            repo: &Repository<'_>,
             commit_shas: &[&CommitSha],
         ) -> Result<AuthorsForCommits> {
-            let query = graph_ql::build_co_authors_query(ORG, REPO, commit_shas.iter().copied());
+            let query = graph_ql::build_co_authors_query(
+                repo.owner.as_ref(),
+                repo.name.as_ref(),
+                commit_shas.iter().copied(),
+            );
             let query = serde_json::json!({ "query": query });
             self.graphql::<graph_ql::CommitAuthorsResponse>(&query)
                 .await
                 .map(|response| response.repository)
         }
 
-        async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
+        async fn check_repo_write_permission(
+            &self,
+            repo: &Repository<'_>,
+            login: &GithubLogin,
+        ) -> Result<bool> {
+            // Check org membership first - we save ourselves a few request that way
             let page = self
                 .client
-                .orgs(ORG)
+                .orgs(repo.owner.as_ref())
                 .list_members()
                 .per_page(PAGE_SIZE)
                 .send()
@@ -438,12 +513,13 @@ mod octo_client {
 
             let members = self.get_all(page).await?;
 
-            Ok(members
+            if members
                 .into_iter()
-                .any(|member| member.login == login.as_str()))
-        }
+                .any(|member| member.login == login.as_str())
+            {
+                return Ok(true);
+            }
 
-        async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool> {
             // TODO: octocrab fails to deserialize the permission response and
             // does not adhere to the scheme laid out at
             // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user
@@ -466,7 +542,9 @@ mod octo_client {
             self.client
                 .get::<RepositoryPermissions, _, _>(
                     format!(
-                        "/repos/{ORG}/{REPO}/collaborators/{user}/permission",
+                        "/repos/{owner}/{repo}/collaborators/{user}/permission",
+                        owner = repo.owner.as_ref(),
+                        repo = repo.name.as_ref(),
                         user = login.as_str()
                     ),
                     None::<&()>,
@@ -481,9 +559,14 @@ mod octo_client {
                 .map_err(Into::into)
         }
 
-        async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()> {
+        async fn add_label_to_issue(
+            &self,
+            repo: &Repository<'_>,
+            label: &str,
+            issue_number: u64,
+        ) -> Result<()> {
             self.client
-                .issues(ORG, REPO)
+                .issues(repo.owner.as_ref(), repo.name.as_ref())
                 .add_labels(issue_number, &[label.to_owned()])
                 .await
                 .map(|_| ())

tooling/xtask/src/tasks/compliance.rs 🔗

@@ -6,7 +6,7 @@ use clap::Parser;
 use compliance::{
     checks::Reporter,
     git::{CommitsFromVersionToVersion, GetVersionTags, GitCommand, VersionTag},
-    github::GitHubClient,
+    github::{GitHubClient, Repository},
     report::ReportReviewSummary,
 };
 
@@ -69,9 +69,10 @@ async fn check_compliance_impl(args: ComplianceArgs) -> Result<()> {
 
     println!("Checking commit range {range}, {} total", commits.len());
 
-    let client = GitHubClient::for_app(
+    let client = GitHubClient::for_app_in_repo(
         app_id.parse().context("Failed to parse app ID as int")?,
         key.as_ref(),
+        Repository::ZED.owner(),
     )
     .await?;
 
@@ -93,7 +94,7 @@ async fn check_compliance_impl(args: ComplianceArgs) -> Result<()> {
 
     for report in report.errors() {
         if let Some(pr_number) = report.commit.pr_number()
-            && let Ok(pull_request) = client.get_pull_request(pr_number).await
+            && let Ok(pull_request) = client.get_pull_request(&Repository::ZED, pr_number).await
             && pull_request.labels.is_none_or(|labels| {
                 labels
                     .iter()
@@ -103,7 +104,11 @@ async fn check_compliance_impl(args: ComplianceArgs) -> Result<()> {
             println!("Adding review label to PR {}...", pr_number);
 
             client
-                .add_label_to_issue(compliance::github::PR_REVIEW_LABEL, pr_number)
+                .add_label_to_issue(
+                    &Repository::ZED,
+                    compliance::github::PR_REVIEW_LABEL,
+                    pr_number,
+                )
                 .await?;
         }
     }