diff --git a/tooling/compliance/src/checks.rs b/tooling/compliance/src/checks.rs index e9d86e7aa3f1924c8b1e619f5890bc8673d3a66f..da8bb7d24d9b6301c4e1badc4f693ae98ff4ee56 100644 --- a/tooling/compliance/src/checks.rs +++ b/tooling/compliance/src/checks.rs @@ -375,11 +375,7 @@ mod tests { Ok(false) } - async fn ensure_pull_request_has_label( - &self, - _label: &str, - _pr_number: u64, - ) -> anyhow::Result<()> { + async fn add_label_to_issue(&self, _label: &str, _pr_number: u64) -> anyhow::Result<()> { Ok(()) } } @@ -439,6 +435,7 @@ mod tests { login: "alice".to_owned(), }), merged_by: None, + labels: None, }, reviews: vec![], comments: vec![], @@ -609,11 +606,11 @@ mod tests { "email": "alice@test.com", "user": { "login": "alice" } }, - "authors": [{ + "authors": { "nodes": [{ "name": "Charlie", "email": "charlie@test.com", "user": { "login": "charlie" } - }] + }] } } })) .with_commit(make_commit( @@ -639,11 +636,11 @@ mod tests { "email": "alice@test.com", "user": { "login": "alice" } }, - "authors": [{ + "authors": { "nodes": [{ "name": "Bob", "email": "bob@test.com", "user": { "login": "bob" } - }] + }] } } })) .with_commit(make_commit( diff --git a/tooling/compliance/src/git.rs b/tooling/compliance/src/git.rs index b1b74fcd8f256f51b9743cdc16bf154deeae3f60..424032d035ad65034f1fc38dc173fd0d1fd8211e 100644 --- a/tooling/compliance/src/git.rs +++ b/tooling/compliance/src/git.rs @@ -110,6 +110,10 @@ impl ToString for VersionTag { pub struct CommitSha(pub(crate) String); impl CommitSha { + pub fn new(sha: String) -> Self { + Self(sha) + } + pub fn short(&self) -> &str { self.0.as_str().split_at(8).0 } diff --git a/tooling/compliance/src/github.rs b/tooling/compliance/src/github.rs index 96a786ccb2a1a16f31688bf3a87f1d862c195838..bb1b55358a1f67fc68809994a6a21e6b56e0a16d 100644 --- a/tooling/compliance/src/github.rs +++ b/tooling/compliance/src/github.rs @@ -18,6 +18,7 @@ pub struct PullRequestData { pub number: u64, pub user: Option, pub merged_by: Option, + pub labels: Option>, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -100,7 +101,7 @@ impl fmt::Display for CommitAuthor { pub struct CommitAuthors { #[serde(rename = "author")] primary_author: CommitAuthor, - #[serde(rename = "authors")] + #[serde(rename = "authors", deserialize_with = "graph_ql::deserialize_nodes")] co_authors: Vec, } @@ -114,9 +115,32 @@ impl CommitAuthors { } } -#[derive(Debug, Deserialize, Deref)] +#[derive(Debug, Deref)] pub struct AuthorsForCommits(HashMap); +impl AuthorsForCommits { + const SHA_PREFIX: &'static str = "commit"; +} + +impl<'de> serde::Deserialize<'de> for AuthorsForCommits { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + let raw = HashMap::::deserialize(deserializer)?; + let map = raw + .into_iter() + .map(|(key, value)| { + let sha = key + .strip_prefix(AuthorsForCommits::SHA_PREFIX) + .unwrap_or(&key); + (CommitSha::new(sha.to_owned()), value) + }) + .collect(); + Ok(Self(map)) + } +} + #[async_trait::async_trait(?Send)] pub trait GitHubApiClient { async fn get_pull_request(&self, pr_number: u64) -> Result; @@ -132,7 +156,7 @@ pub trait GitHubApiClient { Ok(self.check_org_membership(login).await? || self.check_repo_write_permission(login).await?) } - async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()>; + async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()>; } #[derive(Deref)] @@ -152,11 +176,98 @@ impl GitHubClient { } } +pub mod graph_ql { + use anyhow::{Context as _, Result}; + use itertools::Itertools as _; + use serde::Deserialize; + + use crate::git::CommitSha; + + use super::AuthorsForCommits; + + #[derive(Debug, Deserialize)] + pub struct GraphQLResponse { + pub data: Option, + pub errors: Option>, + } + + impl GraphQLResponse { + pub fn into_data(self) -> Result { + if let Some(errors) = &self.errors { + if !errors.is_empty() { + let messages: String = errors.iter().map(|e| e.message.as_str()).join("; "); + anyhow::bail!("GraphQL error: {messages}"); + } + } + + self.data.context("GraphQL response contained no data") + } + } + + #[derive(Debug, Deserialize)] + pub struct GraphQLError { + pub message: String, + } + + #[derive(Debug, Deserialize)] + pub struct CommitAuthorsResponse { + pub repository: AuthorsForCommits, + } + + pub fn deserialize_nodes<'de, T, D>(deserializer: D) -> std::result::Result, D::Error> + where + T: Deserialize<'de>, + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct Nodes { + nodes: Vec, + } + Nodes::::deserialize(deserializer).map(|wrapper| wrapper.nodes) + } + + pub fn build_co_authors_query<'a>( + org: &str, + repo: &str, + shas: impl IntoIterator, + ) -> String { + const FRAGMENT: &str = r#" + ... on Commit { + author { + name + email + user { login } + } + authors(first: 10) { + nodes { + name + email + user { login } + } + } + } + "#; + + let objects = shas + .into_iter() + .map(|commit_sha| { + format!( + "{sha_prefix}{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}", + sha_prefix = AuthorsForCommits::SHA_PREFIX, + sha = **commit_sha, + ) + }) + .join("\n"); + + format!("{{ repository(owner: \"{org}\", name: \"{repo}\") {{ {objects} }} }}") + .replace("\n", "") + } +} + #[cfg(feature = "octo-client")] mod octo_client { use anyhow::{Context, Result}; use futures::TryStreamExt as _; - use itertools::Itertools; use jsonwebtoken::EncodingKey; use octocrab::{ Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState, @@ -165,7 +276,7 @@ mod octo_client { use serde::de::DeserializeOwned; use tokio::pin; - use crate::git::CommitSha; + use crate::{git::CommitSha, github::graph_ql}; use super::{ AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment, @@ -208,43 +319,11 @@ mod octo_client { Ok(Self { client }) } - fn build_co_authors_query<'a>(shas: impl IntoIterator) -> String { - const FRAGMENT: &str = r#" - ... on Commit { - author { - name - email - user { login } - } - authors(first: 10) { - nodes { - name - email - user { login } - } - } - } - "#; - - let objects: String = shas - .into_iter() - .map(|commit_sha| { - format!( - "commit{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}", - sha = **commit_sha - ) - }) - .join("\n"); - - format!("{{ repository(owner: \"{ORG}\", name: \"{REPO}\") {{ {objects} }} }}") - .replace("\n", "") - } - - async fn graphql( - &self, - query: &serde_json::Value, - ) -> octocrab::Result { - self.client.graphql(query).await + async fn graphql(&self, query: &serde_json::Value) -> Result { + let response: serde_json::Value = self.client.graphql(query).await?; + let parsed: graph_ql::GraphQLResponse = serde_json::from_value(response) + .context("Failed to parse GraphQL response envelope")?; + parsed.into_data() } async fn get_all( @@ -282,6 +361,9 @@ mod octo_client { number: pr.number, user: pr.user.map(|user| GitHubUser { login: user.login }), merged_by: pr.merged_by.map(|user| GitHubUser { login: user.login }), + labels: pr + .labels + .map(|labels| labels.into_iter().map(|label| label.name).collect()), }) } @@ -338,32 +420,11 @@ mod octo_client { &self, commit_shas: &[&CommitSha], ) -> Result { - let query = Self::build_co_authors_query(commit_shas.iter().copied()); + let query = graph_ql::build_co_authors_query(ORG, REPO, commit_shas.iter().copied()); let query = serde_json::json!({ "query": query }); - let mut response = self.graphql::(&query).await?; - - response - .get_mut("data") - .and_then(|data| data.get_mut("repository")) - .and_then(|repo| repo.as_object_mut()) - .ok_or_else(|| anyhow::anyhow!("Unexpected response format!")) - .and_then(|commit_data| { - let mut response_map = serde_json::Map::with_capacity(commit_data.len()); - - for (key, value) in commit_data.iter_mut() { - let key_without_prefix = key.strip_prefix("commit").unwrap_or(key); - if let Some(authors) = value.get_mut("authors") { - if let Some(nodes) = authors.get("nodes") { - *authors = nodes.clone(); - } - } - - response_map.insert(key_without_prefix.to_owned(), value.clone()); - } - - serde_json::from_value(serde_json::Value::Object(response_map)) - .context("Failed to deserialize commit authors") - }) + self.graphql::(&query) + .await + .map(|response| response.repository) } async fn check_org_membership(&self, login: &GithubLogin) -> Result { @@ -420,27 +481,13 @@ mod octo_client { .map_err(Into::into) } - async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()> { - if self - .get_filtered( - self.client - .issues(ORG, REPO) - .list_labels_for_issue(pr_number) - .per_page(PAGE_SIZE) - .send() - .await?, - |pr_label| pr_label.name == label, - ) + async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()> { + self.client + .issues(ORG, REPO) + .add_labels(issue_number, &[label.to_owned()]) .await - .is_ok_and(|l| l.is_empty()) - { - self.client - .issues(ORG, REPO) - .add_labels(pr_number, &[label.to_owned()]) - .await?; - } - - Ok(()) + .map(|_| ()) + .map_err(Into::into) } } } diff --git a/tooling/xtask/src/tasks/compliance.rs b/tooling/xtask/src/tasks/compliance.rs index 43d56361719a3df894d690a05c80578edfaccb41..6d055665c6a5cc804841cc90c9af757517cbd14f 100644 --- a/tooling/xtask/src/tasks/compliance.rs +++ b/tooling/xtask/src/tasks/compliance.rs @@ -92,11 +92,18 @@ async fn check_compliance_impl(args: ComplianceArgs) -> Result<()> { ); for report in report.errors() { - if let Some(pr_number) = report.commit.pr_number() { + if let Some(pr_number) = report.commit.pr_number() + && let Ok(pull_request) = client.get_pull_request(pr_number).await + && pull_request.labels.is_none_or(|labels| { + labels + .iter() + .all(|label| label != compliance::github::PR_REVIEW_LABEL) + }) + { println!("Adding review label to PR {}...", pr_number); client - .ensure_pull_request_has_label(compliance::github::PR_REVIEW_LABEL, pr_number) + .add_label_to_issue(compliance::github::PR_REVIEW_LABEL, pr_number) .await?; } }