github.rs

  1use std::{collections::HashMap, fmt, ops::Not, rc::Rc};
  2
  3use anyhow::Result;
  4use derive_more::Deref;
  5use serde::Deserialize;
  6
  7use crate::git::CommitSha;
  8
  9pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
 10
 11#[derive(Debug, Clone)]
 12pub struct GitHubUser {
 13    pub login: String,
 14}
 15
 16#[derive(Debug, Clone)]
 17pub struct PullRequestData {
 18    pub number: u64,
 19    pub user: Option<GitHubUser>,
 20    pub merged_by: Option<GitHubUser>,
 21    pub labels: Option<Vec<String>>,
 22}
 23
 24#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 25pub enum ReviewState {
 26    Approved,
 27    Other,
 28}
 29
 30#[derive(Debug, Clone)]
 31pub struct PullRequestReview {
 32    pub user: Option<GitHubUser>,
 33    pub state: Option<ReviewState>,
 34    pub body: Option<String>,
 35}
 36
 37impl PullRequestReview {
 38    pub fn with_body(self, body: impl ToString) -> Self {
 39        Self {
 40            body: Some(body.to_string()),
 41            ..self
 42        }
 43    }
 44}
 45
 46#[derive(Debug, Clone)]
 47pub struct PullRequestComment {
 48    pub user: GitHubUser,
 49    pub body: Option<String>,
 50}
 51
 52#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
 53pub struct GithubLogin {
 54    login: String,
 55}
 56
 57impl GithubLogin {
 58    pub fn new(login: String) -> Self {
 59        Self { login }
 60    }
 61}
 62
 63impl fmt::Display for GithubLogin {
 64    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
 65        write!(formatter, "@{}", self.login)
 66    }
 67}
 68
 69#[derive(Debug, Deserialize, Clone)]
 70pub struct CommitAuthor {
 71    name: String,
 72    email: String,
 73    user: Option<GithubLogin>,
 74}
 75
 76impl CommitAuthor {
 77    pub(crate) fn user(&self) -> Option<&GithubLogin> {
 78        self.user.as_ref()
 79    }
 80}
 81
 82impl PartialEq for CommitAuthor {
 83    fn eq(&self, other: &Self) -> bool {
 84        self.user.as_ref().zip(other.user.as_ref()).map_or_else(
 85            || self.email == other.email || self.name == other.name,
 86            |(l, r)| l == r,
 87        )
 88    }
 89}
 90
 91impl fmt::Display for CommitAuthor {
 92    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
 93        match self.user.as_ref() {
 94            Some(user) => write!(formatter, "{} ({user})", self.name),
 95            None => write!(formatter, "{} ({})", self.name, self.email),
 96        }
 97    }
 98}
 99
100#[derive(Debug, Deserialize)]
101pub struct CommitAuthors {
102    #[serde(rename = "author")]
103    primary_author: CommitAuthor,
104    #[serde(rename = "authors", deserialize_with = "graph_ql::deserialize_nodes")]
105    co_authors: Vec<CommitAuthor>,
106}
107
108impl CommitAuthors {
109    pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
110        self.co_authors.is_empty().not().then(|| {
111            self.co_authors
112                .iter()
113                .filter(|co_author| *co_author != &self.primary_author)
114        })
115    }
116}
117
118#[derive(Debug, Deref)]
119pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
120
121impl AuthorsForCommits {
122    const SHA_PREFIX: &'static str = "commit";
123}
124
125impl<'de> serde::Deserialize<'de> for AuthorsForCommits {
126    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
127    where
128        D: serde::Deserializer<'de>,
129    {
130        let raw = HashMap::<String, CommitAuthors>::deserialize(deserializer)?;
131        let map = raw
132            .into_iter()
133            .map(|(key, value)| {
134                let sha = key
135                    .strip_prefix(AuthorsForCommits::SHA_PREFIX)
136                    .unwrap_or(&key);
137                (CommitSha::new(sha.to_owned()), value)
138            })
139            .collect();
140        Ok(Self(map))
141    }
142}
143
144#[async_trait::async_trait(?Send)]
145pub trait GitHubApiClient {
146    async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData>;
147    async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>>;
148    async fn get_pull_request_comments(&self, pr_number: u64) -> Result<Vec<PullRequestComment>>;
149    async fn get_commit_authors(&self, commit_shas: &[&CommitSha]) -> Result<AuthorsForCommits>;
150    async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool>;
151    async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool>;
152    async fn actor_has_repository_write_permission(
153        &self,
154        login: &GithubLogin,
155    ) -> anyhow::Result<bool> {
156        Ok(self.check_org_membership(login).await?
157            || self.check_repo_write_permission(login).await?)
158    }
159    async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()>;
160}
161
162#[derive(Deref)]
163pub struct GitHubClient {
164    api: Rc<dyn GitHubApiClient>,
165}
166
167impl GitHubClient {
168    pub fn new(api: Rc<dyn GitHubApiClient>) -> Self {
169        Self { api }
170    }
171
172    #[cfg(feature = "octo-client")]
173    pub async fn for_app(app_id: u64, app_private_key: &str) -> Result<Self> {
174        let client = OctocrabClient::new(app_id, app_private_key).await?;
175        Ok(Self::new(Rc::new(client)))
176    }
177}
178
179pub mod graph_ql {
180    use anyhow::{Context as _, Result};
181    use itertools::Itertools as _;
182    use serde::Deserialize;
183
184    use crate::git::CommitSha;
185
186    use super::AuthorsForCommits;
187
188    #[derive(Debug, Deserialize)]
189    pub struct GraphQLResponse<T> {
190        pub data: Option<T>,
191        pub errors: Option<Vec<GraphQLError>>,
192    }
193
194    impl<T> GraphQLResponse<T> {
195        pub fn into_data(self) -> Result<T> {
196            if let Some(errors) = &self.errors {
197                if !errors.is_empty() {
198                    let messages: String = errors.iter().map(|e| e.message.as_str()).join("; ");
199                    anyhow::bail!("GraphQL error: {messages}");
200                }
201            }
202
203            self.data.context("GraphQL response contained no data")
204        }
205    }
206
207    #[derive(Debug, Deserialize)]
208    pub struct GraphQLError {
209        pub message: String,
210    }
211
212    #[derive(Debug, Deserialize)]
213    pub struct CommitAuthorsResponse {
214        pub repository: AuthorsForCommits,
215    }
216
217    pub fn deserialize_nodes<'de, T, D>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
218    where
219        T: Deserialize<'de>,
220        D: serde::Deserializer<'de>,
221    {
222        #[derive(Deserialize)]
223        struct Nodes<T> {
224            nodes: Vec<T>,
225        }
226        Nodes::<T>::deserialize(deserializer).map(|wrapper| wrapper.nodes)
227    }
228
229    pub fn build_co_authors_query<'a>(
230        org: &str,
231        repo: &str,
232        shas: impl IntoIterator<Item = &'a CommitSha>,
233    ) -> String {
234        const FRAGMENT: &str = r#"
235            ... on Commit {
236                author {
237                    name
238                    email
239                    user { login }
240                }
241                authors(first: 10) {
242                    nodes {
243                        name
244                        email
245                        user { login }
246                    }
247                }
248            }
249        "#;
250
251        let objects = shas
252            .into_iter()
253            .map(|commit_sha| {
254                format!(
255                    "{sha_prefix}{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
256                    sha_prefix = AuthorsForCommits::SHA_PREFIX,
257                    sha = **commit_sha,
258                )
259            })
260            .join("\n");
261
262        format!("{{  repository(owner: \"{org}\", name: \"{repo}\") {{ {objects}  }} }}")
263            .replace("\n", "")
264    }
265}
266
267#[cfg(feature = "octo-client")]
268mod octo_client {
269    use anyhow::{Context, Result};
270    use futures::TryStreamExt as _;
271    use jsonwebtoken::EncodingKey;
272    use octocrab::{
273        Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
274        service::middleware::cache::mem::InMemoryCache,
275    };
276    use serde::de::DeserializeOwned;
277    use tokio::pin;
278
279    use crate::{git::CommitSha, github::graph_ql};
280
281    use super::{
282        AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment,
283        PullRequestData, PullRequestReview, ReviewState,
284    };
285
286    const PAGE_SIZE: u8 = 100;
287    const ORG: &str = "zed-industries";
288    const REPO: &str = "zed";
289
290    pub struct OctocrabClient {
291        client: Octocrab,
292    }
293
294    impl OctocrabClient {
295        pub async fn new(app_id: u64, app_private_key: &str) -> Result<Self> {
296            let octocrab = Octocrab::builder()
297                .cache(InMemoryCache::new())
298                .app(
299                    app_id.into(),
300                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
301                )
302                .build()?;
303
304            let installations = octocrab
305                .apps()
306                .installations()
307                .send()
308                .await
309                .context("Failed to fetch installations")?
310                .take_items();
311
312            let installation_id = installations
313                .into_iter()
314                .find(|installation| installation.account.login == ORG)
315                .context("Could not find Zed repository in installations")?
316                .id;
317
318            let client = octocrab.installation(installation_id)?;
319            Ok(Self { client })
320        }
321
322        async fn graphql<R: DeserializeOwned>(&self, query: &serde_json::Value) -> Result<R> {
323            let response: serde_json::Value = self.client.graphql(query).await?;
324            let parsed: graph_ql::GraphQLResponse<R> = serde_json::from_value(response)
325                .context("Failed to parse GraphQL response envelope")?;
326            parsed.into_data()
327        }
328
329        async fn get_all<T: DeserializeOwned + 'static>(
330            &self,
331            page: Page<T>,
332        ) -> octocrab::Result<Vec<T>> {
333            self.get_filtered(page, |_| true).await
334        }
335
336        async fn get_filtered<T: DeserializeOwned + 'static>(
337            &self,
338            page: Page<T>,
339            predicate: impl Fn(&T) -> bool,
340        ) -> octocrab::Result<Vec<T>> {
341            let stream = page.into_stream(&self.client);
342            pin!(stream);
343
344            let mut results = Vec::new();
345
346            while let Some(item) = stream.try_next().await?
347                && predicate(&item)
348            {
349                results.push(item);
350            }
351
352            Ok(results)
353        }
354    }
355
356    #[async_trait::async_trait(?Send)]
357    impl GitHubApiClient for OctocrabClient {
358        async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
359            let pr = self.client.pulls(ORG, REPO).get(pr_number).await?;
360            Ok(PullRequestData {
361                number: pr.number,
362                user: pr.user.map(|user| GitHubUser { login: user.login }),
363                merged_by: pr.merged_by.map(|user| GitHubUser { login: user.login }),
364                labels: pr
365                    .labels
366                    .map(|labels| labels.into_iter().map(|label| label.name).collect()),
367            })
368        }
369
370        async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
371            let page = self
372                .client
373                .pulls(ORG, REPO)
374                .list_reviews(pr_number)
375                .per_page(PAGE_SIZE)
376                .send()
377                .await?;
378
379            let reviews = self.get_all(page).await?;
380
381            Ok(reviews
382                .into_iter()
383                .map(|review| PullRequestReview {
384                    user: review.user.map(|user| GitHubUser { login: user.login }),
385                    state: review.state.map(|state| match state {
386                        OctocrabReviewState::Approved => ReviewState::Approved,
387                        _ => ReviewState::Other,
388                    }),
389                    body: review.body,
390                })
391                .collect())
392        }
393
394        async fn get_pull_request_comments(
395            &self,
396            pr_number: u64,
397        ) -> Result<Vec<PullRequestComment>> {
398            let page = self
399                .client
400                .issues(ORG, REPO)
401                .list_comments(pr_number)
402                .per_page(PAGE_SIZE)
403                .send()
404                .await?;
405
406            let comments = self.get_all(page).await?;
407
408            Ok(comments
409                .into_iter()
410                .map(|comment| PullRequestComment {
411                    user: GitHubUser {
412                        login: comment.user.login,
413                    },
414                    body: comment.body,
415                })
416                .collect())
417        }
418
419        async fn get_commit_authors(
420            &self,
421            commit_shas: &[&CommitSha],
422        ) -> Result<AuthorsForCommits> {
423            let query = graph_ql::build_co_authors_query(ORG, REPO, commit_shas.iter().copied());
424            let query = serde_json::json!({ "query": query });
425            self.graphql::<graph_ql::CommitAuthorsResponse>(&query)
426                .await
427                .map(|response| response.repository)
428        }
429
430        async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
431            let page = self
432                .client
433                .orgs(ORG)
434                .list_members()
435                .per_page(PAGE_SIZE)
436                .send()
437                .await?;
438
439            let members = self.get_all(page).await?;
440
441            Ok(members
442                .into_iter()
443                .any(|member| member.login == login.as_str()))
444        }
445
446        async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool> {
447            // TODO: octocrab fails to deserialize the permission response and
448            // does not adhere to the scheme laid out at
449            // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user
450
451            #[derive(serde::Deserialize)]
452            #[serde(rename_all = "lowercase")]
453            enum RepoPermission {
454                Admin,
455                Write,
456                Read,
457                #[serde(other)]
458                Other,
459            }
460
461            #[derive(serde::Deserialize)]
462            struct RepositoryPermissions {
463                permission: RepoPermission,
464            }
465
466            self.client
467                .get::<RepositoryPermissions, _, _>(
468                    format!(
469                        "/repos/{ORG}/{REPO}/collaborators/{user}/permission",
470                        user = login.as_str()
471                    ),
472                    None::<&()>,
473                )
474                .await
475                .map(|response| {
476                    matches!(
477                        response.permission,
478                        RepoPermission::Write | RepoPermission::Admin
479                    )
480                })
481                .map_err(Into::into)
482        }
483
484        async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()> {
485            self.client
486                .issues(ORG, REPO)
487                .add_labels(issue_number, &[label.to_owned()])
488                .await
489                .map(|_| ())
490                .map_err(Into::into)
491        }
492    }
493}
494
495#[cfg(feature = "octo-client")]
496pub use octo_client::OctocrabClient;