@@ -375,11 +375,7 @@ mod tests {
Ok(false)
}
- async fn ensure_pull_request_has_label(
- &self,
- _label: &str,
- _pr_number: u64,
- ) -> anyhow::Result<()> {
+ async fn add_label_to_issue(&self, _label: &str, _pr_number: u64) -> anyhow::Result<()> {
Ok(())
}
}
@@ -439,6 +435,7 @@ mod tests {
login: "alice".to_owned(),
}),
merged_by: None,
+ labels: None,
},
reviews: vec![],
comments: vec![],
@@ -609,11 +606,11 @@ mod tests {
"email": "alice@test.com",
"user": { "login": "alice" }
},
- "authors": [{
+ "authors": { "nodes": [{
"name": "Charlie",
"email": "charlie@test.com",
"user": { "login": "charlie" }
- }]
+ }] }
}
}))
.with_commit(make_commit(
@@ -639,11 +636,11 @@ mod tests {
"email": "alice@test.com",
"user": { "login": "alice" }
},
- "authors": [{
+ "authors": { "nodes": [{
"name": "Bob",
"email": "bob@test.com",
"user": { "login": "bob" }
- }]
+ }] }
}
}))
.with_commit(make_commit(
@@ -18,6 +18,7 @@ pub struct PullRequestData {
pub number: u64,
pub user: Option<GitHubUser>,
pub merged_by: Option<GitHubUser>,
+ pub labels: Option<Vec<String>>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -100,7 +101,7 @@ impl fmt::Display for CommitAuthor {
pub struct CommitAuthors {
#[serde(rename = "author")]
primary_author: CommitAuthor,
- #[serde(rename = "authors")]
+ #[serde(rename = "authors", deserialize_with = "graph_ql::deserialize_nodes")]
co_authors: Vec<CommitAuthor>,
}
@@ -114,9 +115,32 @@ impl CommitAuthors {
}
}
-#[derive(Debug, Deserialize, Deref)]
+#[derive(Debug, Deref)]
pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
+impl AuthorsForCommits {
+ const SHA_PREFIX: &'static str = "commit";
+}
+
+impl<'de> serde::Deserialize<'de> for AuthorsForCommits {
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ let raw = HashMap::<String, CommitAuthors>::deserialize(deserializer)?;
+ let map = raw
+ .into_iter()
+ .map(|(key, value)| {
+ let sha = key
+ .strip_prefix(AuthorsForCommits::SHA_PREFIX)
+ .unwrap_or(&key);
+ (CommitSha::new(sha.to_owned()), value)
+ })
+ .collect();
+ Ok(Self(map))
+ }
+}
+
#[async_trait::async_trait(?Send)]
pub trait GitHubApiClient {
async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData>;
@@ -132,7 +156,7 @@ pub trait GitHubApiClient {
Ok(self.check_org_membership(login).await?
|| self.check_repo_write_permission(login).await?)
}
- async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()>;
+ async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()>;
}
#[derive(Deref)]
@@ -152,11 +176,98 @@ impl GitHubClient {
}
}
+pub mod graph_ql {
+ use anyhow::{Context as _, Result};
+ use itertools::Itertools as _;
+ use serde::Deserialize;
+
+ use crate::git::CommitSha;
+
+ use super::AuthorsForCommits;
+
+ #[derive(Debug, Deserialize)]
+ pub struct GraphQLResponse<T> {
+ pub data: Option<T>,
+ pub errors: Option<Vec<GraphQLError>>,
+ }
+
+ impl<T> GraphQLResponse<T> {
+ pub fn into_data(self) -> Result<T> {
+ if let Some(errors) = &self.errors {
+ if !errors.is_empty() {
+ let messages: String = errors.iter().map(|e| e.message.as_str()).join("; ");
+ anyhow::bail!("GraphQL error: {messages}");
+ }
+ }
+
+ self.data.context("GraphQL response contained no data")
+ }
+ }
+
+ #[derive(Debug, Deserialize)]
+ pub struct GraphQLError {
+ pub message: String,
+ }
+
+ #[derive(Debug, Deserialize)]
+ pub struct CommitAuthorsResponse {
+ pub repository: AuthorsForCommits,
+ }
+
+ pub fn deserialize_nodes<'de, T, D>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
+ where
+ T: Deserialize<'de>,
+ D: serde::Deserializer<'de>,
+ {
+ #[derive(Deserialize)]
+ struct Nodes<T> {
+ nodes: Vec<T>,
+ }
+ Nodes::<T>::deserialize(deserializer).map(|wrapper| wrapper.nodes)
+ }
+
+ pub fn build_co_authors_query<'a>(
+ org: &str,
+ repo: &str,
+ shas: impl IntoIterator<Item = &'a CommitSha>,
+ ) -> String {
+ const FRAGMENT: &str = r#"
+ ... on Commit {
+ author {
+ name
+ email
+ user { login }
+ }
+ authors(first: 10) {
+ nodes {
+ name
+ email
+ user { login }
+ }
+ }
+ }
+ "#;
+
+ let objects = shas
+ .into_iter()
+ .map(|commit_sha| {
+ format!(
+ "{sha_prefix}{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
+ sha_prefix = AuthorsForCommits::SHA_PREFIX,
+ sha = **commit_sha,
+ )
+ })
+ .join("\n");
+
+ format!("{{ repository(owner: \"{org}\", name: \"{repo}\") {{ {objects} }} }}")
+ .replace("\n", "")
+ }
+}
+
#[cfg(feature = "octo-client")]
mod octo_client {
use anyhow::{Context, Result};
use futures::TryStreamExt as _;
- use itertools::Itertools;
use jsonwebtoken::EncodingKey;
use octocrab::{
Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
@@ -165,7 +276,7 @@ mod octo_client {
use serde::de::DeserializeOwned;
use tokio::pin;
- use crate::git::CommitSha;
+ use crate::{git::CommitSha, github::graph_ql};
use super::{
AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment,
@@ -208,43 +319,11 @@ mod octo_client {
Ok(Self { client })
}
- fn build_co_authors_query<'a>(shas: impl IntoIterator<Item = &'a CommitSha>) -> String {
- const FRAGMENT: &str = r#"
- ... on Commit {
- author {
- name
- email
- user { login }
- }
- authors(first: 10) {
- nodes {
- name
- email
- user { login }
- }
- }
- }
- "#;
-
- let objects: String = shas
- .into_iter()
- .map(|commit_sha| {
- format!(
- "commit{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
- sha = **commit_sha
- )
- })
- .join("\n");
-
- format!("{{ repository(owner: \"{ORG}\", name: \"{REPO}\") {{ {objects} }} }}")
- .replace("\n", "")
- }
-
- async fn graphql<R: octocrab::FromResponse>(
- &self,
- query: &serde_json::Value,
- ) -> octocrab::Result<R> {
- self.client.graphql(query).await
+ async fn graphql<R: DeserializeOwned>(&self, query: &serde_json::Value) -> Result<R> {
+ let response: serde_json::Value = self.client.graphql(query).await?;
+ let parsed: graph_ql::GraphQLResponse<R> = serde_json::from_value(response)
+ .context("Failed to parse GraphQL response envelope")?;
+ parsed.into_data()
}
async fn get_all<T: DeserializeOwned + 'static>(
@@ -282,6 +361,9 @@ mod octo_client {
number: pr.number,
user: pr.user.map(|user| GitHubUser { login: user.login }),
merged_by: pr.merged_by.map(|user| GitHubUser { login: user.login }),
+ labels: pr
+ .labels
+ .map(|labels| labels.into_iter().map(|label| label.name).collect()),
})
}
@@ -338,32 +420,11 @@ mod octo_client {
&self,
commit_shas: &[&CommitSha],
) -> Result<AuthorsForCommits> {
- let query = Self::build_co_authors_query(commit_shas.iter().copied());
+ let query = graph_ql::build_co_authors_query(ORG, REPO, commit_shas.iter().copied());
let query = serde_json::json!({ "query": query });
- let mut response = self.graphql::<serde_json::Value>(&query).await?;
-
- response
- .get_mut("data")
- .and_then(|data| data.get_mut("repository"))
- .and_then(|repo| repo.as_object_mut())
- .ok_or_else(|| anyhow::anyhow!("Unexpected response format!"))
- .and_then(|commit_data| {
- let mut response_map = serde_json::Map::with_capacity(commit_data.len());
-
- for (key, value) in commit_data.iter_mut() {
- let key_without_prefix = key.strip_prefix("commit").unwrap_or(key);
- if let Some(authors) = value.get_mut("authors") {
- if let Some(nodes) = authors.get("nodes") {
- *authors = nodes.clone();
- }
- }
-
- response_map.insert(key_without_prefix.to_owned(), value.clone());
- }
-
- serde_json::from_value(serde_json::Value::Object(response_map))
- .context("Failed to deserialize commit authors")
- })
+ self.graphql::<graph_ql::CommitAuthorsResponse>(&query)
+ .await
+ .map(|response| response.repository)
}
async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
@@ -420,27 +481,13 @@ mod octo_client {
.map_err(Into::into)
}
- async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()> {
- if self
- .get_filtered(
- self.client
- .issues(ORG, REPO)
- .list_labels_for_issue(pr_number)
- .per_page(PAGE_SIZE)
- .send()
- .await?,
- |pr_label| pr_label.name == label,
- )
+ async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()> {
+ self.client
+ .issues(ORG, REPO)
+ .add_labels(issue_number, &[label.to_owned()])
.await
- .is_ok_and(|l| l.is_empty())
- {
- self.client
- .issues(ORG, REPO)
- .add_labels(pr_number, &[label.to_owned()])
- .await?;
- }
-
- Ok(())
+ .map(|_| ())
+ .map_err(Into::into)
}
}
}