compliance: Deserialize GraphQL responses better (#53817)

Finn Evers created

This will help a lot when dealing with these GraphQL responses in the
context of the Zed Zippy bot.

Release Notes:

- N/A

Change summary

tooling/compliance/src/checks.rs      |  15 -
tooling/compliance/src/git.rs         |   4 
tooling/compliance/src/github.rs      | 221 +++++++++++++++++-----------
tooling/xtask/src/tasks/compliance.rs |  11 +
4 files changed, 153 insertions(+), 98 deletions(-)

Detailed changes

tooling/compliance/src/checks.rs 🔗

@@ -375,11 +375,7 @@ mod tests {
             Ok(false)
         }
 
-        async fn ensure_pull_request_has_label(
-            &self,
-            _label: &str,
-            _pr_number: u64,
-        ) -> anyhow::Result<()> {
+        async fn add_label_to_issue(&self, _label: &str, _pr_number: u64) -> anyhow::Result<()> {
             Ok(())
         }
     }
@@ -439,6 +435,7 @@ mod tests {
                         login: "alice".to_owned(),
                     }),
                     merged_by: None,
+                    labels: None,
                 },
                 reviews: vec![],
                 comments: vec![],
@@ -609,11 +606,11 @@ mod tests {
                         "email": "alice@test.com",
                         "user": { "login": "alice" }
                     },
-                    "authors": [{
+                    "authors": { "nodes": [{
                         "name": "Charlie",
                         "email": "charlie@test.com",
                         "user": { "login": "charlie" }
-                    }]
+                    }] }
                 }
             }))
             .with_commit(make_commit(
@@ -639,11 +636,11 @@ mod tests {
                         "email": "alice@test.com",
                         "user": { "login": "alice" }
                     },
-                    "authors": [{
+                    "authors": { "nodes": [{
                         "name": "Bob",
                         "email": "bob@test.com",
                         "user": { "login": "bob" }
-                    }]
+                    }] }
                 }
             }))
             .with_commit(make_commit(

tooling/compliance/src/git.rs 🔗

@@ -110,6 +110,10 @@ impl ToString for VersionTag {
 pub struct CommitSha(pub(crate) String);
 
 impl CommitSha {
+    pub fn new(sha: String) -> Self {
+        Self(sha)
+    }
+
     pub fn short(&self) -> &str {
         self.0.as_str().split_at(8).0
     }

tooling/compliance/src/github.rs 🔗

@@ -18,6 +18,7 @@ pub struct PullRequestData {
     pub number: u64,
     pub user: Option<GitHubUser>,
     pub merged_by: Option<GitHubUser>,
+    pub labels: Option<Vec<String>>,
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -100,7 +101,7 @@ impl fmt::Display for CommitAuthor {
 pub struct CommitAuthors {
     #[serde(rename = "author")]
     primary_author: CommitAuthor,
-    #[serde(rename = "authors")]
+    #[serde(rename = "authors", deserialize_with = "graph_ql::deserialize_nodes")]
     co_authors: Vec<CommitAuthor>,
 }
 
@@ -114,9 +115,32 @@ impl CommitAuthors {
     }
 }
 
-#[derive(Debug, Deserialize, Deref)]
+#[derive(Debug, Deref)]
 pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
 
+impl AuthorsForCommits {
+    const SHA_PREFIX: &'static str = "commit";
+}
+
+impl<'de> serde::Deserialize<'de> for AuthorsForCommits {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        let raw = HashMap::<String, CommitAuthors>::deserialize(deserializer)?;
+        let map = raw
+            .into_iter()
+            .map(|(key, value)| {
+                let sha = key
+                    .strip_prefix(AuthorsForCommits::SHA_PREFIX)
+                    .unwrap_or(&key);
+                (CommitSha::new(sha.to_owned()), value)
+            })
+            .collect();
+        Ok(Self(map))
+    }
+}
+
 #[async_trait::async_trait(?Send)]
 pub trait GitHubApiClient {
     async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData>;
@@ -132,7 +156,7 @@ pub trait GitHubApiClient {
         Ok(self.check_org_membership(login).await?
             || self.check_repo_write_permission(login).await?)
     }
-    async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()>;
+    async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()>;
 }
 
 #[derive(Deref)]
@@ -152,11 +176,98 @@ impl GitHubClient {
     }
 }
 
+pub mod graph_ql {
+    use anyhow::{Context as _, Result};
+    use itertools::Itertools as _;
+    use serde::Deserialize;
+
+    use crate::git::CommitSha;
+
+    use super::AuthorsForCommits;
+
+    #[derive(Debug, Deserialize)]
+    pub struct GraphQLResponse<T> {
+        pub data: Option<T>,
+        pub errors: Option<Vec<GraphQLError>>,
+    }
+
+    impl<T> GraphQLResponse<T> {
+        pub fn into_data(self) -> Result<T> {
+            if let Some(errors) = &self.errors {
+                if !errors.is_empty() {
+                    let messages: String = errors.iter().map(|e| e.message.as_str()).join("; ");
+                    anyhow::bail!("GraphQL error: {messages}");
+                }
+            }
+
+            self.data.context("GraphQL response contained no data")
+        }
+    }
+
+    #[derive(Debug, Deserialize)]
+    pub struct GraphQLError {
+        pub message: String,
+    }
+
+    #[derive(Debug, Deserialize)]
+    pub struct CommitAuthorsResponse {
+        pub repository: AuthorsForCommits,
+    }
+
+    pub fn deserialize_nodes<'de, T, D>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
+    where
+        T: Deserialize<'de>,
+        D: serde::Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        struct Nodes<T> {
+            nodes: Vec<T>,
+        }
+        Nodes::<T>::deserialize(deserializer).map(|wrapper| wrapper.nodes)
+    }
+
+    pub fn build_co_authors_query<'a>(
+        org: &str,
+        repo: &str,
+        shas: impl IntoIterator<Item = &'a CommitSha>,
+    ) -> String {
+        const FRAGMENT: &str = r#"
+            ... on Commit {
+                author {
+                    name
+                    email
+                    user { login }
+                }
+                authors(first: 10) {
+                    nodes {
+                        name
+                        email
+                        user { login }
+                    }
+                }
+            }
+        "#;
+
+        let objects = shas
+            .into_iter()
+            .map(|commit_sha| {
+                format!(
+                    "{sha_prefix}{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
+                    sha_prefix = AuthorsForCommits::SHA_PREFIX,
+                    sha = **commit_sha,
+                )
+            })
+            .join("\n");
+
+        format!("{{  repository(owner: \"{org}\", name: \"{repo}\") {{ {objects}  }} }}")
+            .replace("\n", "")
+    }
+}
+
 #[cfg(feature = "octo-client")]
 mod octo_client {
     use anyhow::{Context, Result};
     use futures::TryStreamExt as _;
-    use itertools::Itertools;
     use jsonwebtoken::EncodingKey;
     use octocrab::{
         Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
@@ -165,7 +276,7 @@ mod octo_client {
     use serde::de::DeserializeOwned;
     use tokio::pin;
 
-    use crate::git::CommitSha;
+    use crate::{git::CommitSha, github::graph_ql};
 
     use super::{
         AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment,
@@ -208,43 +319,11 @@ mod octo_client {
             Ok(Self { client })
         }
 
-        fn build_co_authors_query<'a>(shas: impl IntoIterator<Item = &'a CommitSha>) -> String {
-            const FRAGMENT: &str = r#"
-                ... on Commit {
-                    author {
-                        name
-                        email
-                        user { login }
-                    }
-                    authors(first: 10) {
-                        nodes {
-                            name
-                            email
-                            user { login }
-                        }
-                    }
-                }
-            "#;
-
-            let objects: String = shas
-                .into_iter()
-                .map(|commit_sha| {
-                    format!(
-                        "commit{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
-                        sha = **commit_sha
-                    )
-                })
-                .join("\n");
-
-            format!("{{  repository(owner: \"{ORG}\", name: \"{REPO}\") {{ {objects}  }} }}")
-                .replace("\n", "")
-        }
-
-        async fn graphql<R: octocrab::FromResponse>(
-            &self,
-            query: &serde_json::Value,
-        ) -> octocrab::Result<R> {
-            self.client.graphql(query).await
+        async fn graphql<R: DeserializeOwned>(&self, query: &serde_json::Value) -> Result<R> {
+            let response: serde_json::Value = self.client.graphql(query).await?;
+            let parsed: graph_ql::GraphQLResponse<R> = serde_json::from_value(response)
+                .context("Failed to parse GraphQL response envelope")?;
+            parsed.into_data()
         }
 
         async fn get_all<T: DeserializeOwned + 'static>(
@@ -282,6 +361,9 @@ mod octo_client {
                 number: pr.number,
                 user: pr.user.map(|user| GitHubUser { login: user.login }),
                 merged_by: pr.merged_by.map(|user| GitHubUser { login: user.login }),
+                labels: pr
+                    .labels
+                    .map(|labels| labels.into_iter().map(|label| label.name).collect()),
             })
         }
 
@@ -338,32 +420,11 @@ mod octo_client {
             &self,
             commit_shas: &[&CommitSha],
         ) -> Result<AuthorsForCommits> {
-            let query = Self::build_co_authors_query(commit_shas.iter().copied());
+            let query = graph_ql::build_co_authors_query(ORG, REPO, commit_shas.iter().copied());
             let query = serde_json::json!({ "query": query });
-            let mut response = self.graphql::<serde_json::Value>(&query).await?;
-
-            response
-                .get_mut("data")
-                .and_then(|data| data.get_mut("repository"))
-                .and_then(|repo| repo.as_object_mut())
-                .ok_or_else(|| anyhow::anyhow!("Unexpected response format!"))
-                .and_then(|commit_data| {
-                    let mut response_map = serde_json::Map::with_capacity(commit_data.len());
-
-                    for (key, value) in commit_data.iter_mut() {
-                        let key_without_prefix = key.strip_prefix("commit").unwrap_or(key);
-                        if let Some(authors) = value.get_mut("authors") {
-                            if let Some(nodes) = authors.get("nodes") {
-                                *authors = nodes.clone();
-                            }
-                        }
-
-                        response_map.insert(key_without_prefix.to_owned(), value.clone());
-                    }
-
-                    serde_json::from_value(serde_json::Value::Object(response_map))
-                        .context("Failed to deserialize commit authors")
-                })
+            self.graphql::<graph_ql::CommitAuthorsResponse>(&query)
+                .await
+                .map(|response| response.repository)
         }
 
         async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
@@ -420,27 +481,13 @@ mod octo_client {
                 .map_err(Into::into)
         }
 
-        async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()> {
-            if self
-                .get_filtered(
-                    self.client
-                        .issues(ORG, REPO)
-                        .list_labels_for_issue(pr_number)
-                        .per_page(PAGE_SIZE)
-                        .send()
-                        .await?,
-                    |pr_label| pr_label.name == label,
-                )
+        async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()> {
+            self.client
+                .issues(ORG, REPO)
+                .add_labels(issue_number, &[label.to_owned()])
                 .await
-                .is_ok_and(|l| l.is_empty())
-            {
-                self.client
-                    .issues(ORG, REPO)
-                    .add_labels(pr_number, &[label.to_owned()])
-                    .await?;
-            }
-
-            Ok(())
+                .map(|_| ())
+                .map_err(Into::into)
         }
     }
 }

tooling/xtask/src/tasks/compliance.rs 🔗

@@ -92,11 +92,18 @@ async fn check_compliance_impl(args: ComplianceArgs) -> Result<()> {
     );
 
     for report in report.errors() {
-        if let Some(pr_number) = report.commit.pr_number() {
+        if let Some(pr_number) = report.commit.pr_number()
+            && let Ok(pull_request) = client.get_pull_request(pr_number).await
+            && pull_request.labels.is_none_or(|labels| {
+                labels
+                    .iter()
+                    .all(|label| label != compliance::github::PR_REVIEW_LABEL)
+            })
+        {
             println!("Adding review label to PR {}...", pr_number);
 
             client
-                .ensure_pull_request_has_label(compliance::github::PR_REVIEW_LABEL, pr_number)
+                .add_label_to_issue(compliance::github::PR_REVIEW_LABEL, pr_number)
                 .await?;
         }
     }