github.rs

  1use std::{borrow::Cow, collections::HashMap, fmt};
  2
  3use anyhow::Result;
  4use derive_more::Deref;
  5use serde::Deserialize;
  6
  7use crate::git::CommitSha;
  8
  9pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
 10
 11#[derive(Debug, Clone)]
 12pub struct GithubUser {
 13    pub login: String,
 14}
 15
 16#[derive(Debug, Clone)]
 17pub struct PullRequestData {
 18    pub number: u64,
 19    pub user: Option<GithubUser>,
 20    pub merged_by: Option<GithubUser>,
 21    pub labels: Option<Vec<String>>,
 22}
 23
 24#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 25pub enum ReviewState {
 26    Approved,
 27    Other,
 28}
 29
 30#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 31pub enum AuthorAssociation {
 32    Owner,
 33    Member,
 34    Collaborator,
 35    Contributor,
 36    FirstTimeContributor,
 37    FirstTimer,
 38    Mannequin,
 39    None,
 40}
 41
 42impl AuthorAssociation {
 43    pub fn has_write_access(&self) -> bool {
 44        matches!(self, Self::Owner | Self::Member | Self::Collaborator)
 45    }
 46}
 47
 48pub trait Approvable {
 49    fn author_login(&self) -> Option<&str>;
 50    fn review_state(&self) -> Option<ReviewState>;
 51    fn body(&self) -> Option<&str>;
 52    fn author_association(&self) -> Option<AuthorAssociation>;
 53}
 54
 55#[derive(Debug, Clone)]
 56pub struct PullRequestReview {
 57    pub user: Option<GithubUser>,
 58    pub state: Option<ReviewState>,
 59    pub body: Option<String>,
 60    pub author_association: Option<AuthorAssociation>,
 61}
 62
 63impl PullRequestReview {
 64    pub fn with_body(self, body: impl ToString) -> Self {
 65        Self {
 66            body: Some(body.to_string()),
 67            ..self
 68        }
 69    }
 70}
 71
 72impl Approvable for PullRequestReview {
 73    fn author_login(&self) -> Option<&str> {
 74        self.user.as_ref().map(|user| user.login.as_str())
 75    }
 76
 77    fn review_state(&self) -> Option<ReviewState> {
 78        self.state
 79    }
 80
 81    fn body(&self) -> Option<&str> {
 82        self.body.as_deref()
 83    }
 84
 85    fn author_association(&self) -> Option<AuthorAssociation> {
 86        self.author_association
 87    }
 88}
 89
 90#[derive(Debug, Clone)]
 91pub struct PullRequestComment {
 92    pub user: GithubUser,
 93    pub body: Option<String>,
 94    pub author_association: Option<AuthorAssociation>,
 95}
 96
 97impl Approvable for PullRequestComment {
 98    fn author_login(&self) -> Option<&str> {
 99        Some(&self.user.login)
100    }
101
102    fn review_state(&self) -> Option<ReviewState> {
103        None
104    }
105
106    fn body(&self) -> Option<&str> {
107        self.body.as_deref()
108    }
109
110    fn author_association(&self) -> Option<AuthorAssociation> {
111        self.author_association
112    }
113}
114
115#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
116pub struct GithubLogin {
117    login: String,
118}
119
120impl GithubLogin {
121    pub fn new(login: String) -> Self {
122        Self { login }
123    }
124}
125
126impl fmt::Display for GithubLogin {
127    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
128        write!(formatter, "@{}", self.login)
129    }
130}
131
132#[derive(Debug, Deserialize, Clone)]
133pub struct CommitAuthor {
134    name: String,
135    email: String,
136    user: Option<GithubLogin>,
137}
138
139impl CommitAuthor {
140    pub(crate) fn user(&self) -> Option<&GithubLogin> {
141        self.user.as_ref()
142    }
143}
144
145impl PartialEq for CommitAuthor {
146    fn eq(&self, other: &Self) -> bool {
147        self.user.as_ref().zip(other.user.as_ref()).map_or_else(
148            || self.email == other.email || self.name == other.name,
149            |(l, r)| l == r,
150        )
151    }
152}
153
154impl fmt::Display for CommitAuthor {
155    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
156        match self.user.as_ref() {
157            Some(user) => write!(formatter, "{} ({user})", self.name),
158            None => write!(formatter, "{} ({})", self.name, self.email),
159        }
160    }
161}
162
163#[derive(Debug, Deserialize, Clone)]
164pub struct CommitSignature {
165    #[serde(rename = "isValid")]
166    is_valid: bool,
167    signer: Option<GithubLogin>,
168}
169
170impl CommitSignature {
171    pub fn is_valid(&self) -> bool {
172        self.is_valid
173    }
174
175    pub fn signer(&self) -> Option<&GithubLogin> {
176        self.signer.as_ref()
177    }
178}
179
180#[derive(Debug, Clone, Deserialize)]
181pub struct CommitFileChange {
182    pub filename: String,
183}
184
185#[derive(Debug, Deserialize)]
186pub struct CommitMetadata {
187    #[serde(rename = "author")]
188    primary_author: CommitAuthor,
189    #[serde(rename = "authors", deserialize_with = "graph_ql::deserialize_nodes")]
190    co_authors: Vec<CommitAuthor>,
191    #[serde(default)]
192    signature: Option<CommitSignature>,
193    #[serde(default)]
194    additions: u64,
195    #[serde(default)]
196    deletions: u64,
197}
198
199impl CommitMetadata {
200    pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
201        let mut co_authors = self
202            .co_authors
203            .iter()
204            .filter(|co_author| *co_author != &self.primary_author)
205            .peekable();
206
207        co_authors.peek().is_some().then_some(co_authors)
208    }
209
210    pub fn primary_author(&self) -> &CommitAuthor {
211        &self.primary_author
212    }
213
214    pub fn signature(&self) -> Option<&CommitSignature> {
215        self.signature.as_ref()
216    }
217
218    pub fn additions(&self) -> u64 {
219        self.additions
220    }
221
222    pub fn deletions(&self) -> u64 {
223        self.deletions
224    }
225}
226
227#[derive(Debug, Deref)]
228pub struct CommitMetadataBySha(HashMap<CommitSha, CommitMetadata>);
229
230impl CommitMetadataBySha {
231    const SHA_PREFIX: &'static str = "commit";
232}
233
234impl<'de> serde::Deserialize<'de> for CommitMetadataBySha {
235    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
236    where
237        D: serde::Deserializer<'de>,
238    {
239        let raw = HashMap::<String, CommitMetadata>::deserialize(deserializer)?;
240        let map = raw
241            .into_iter()
242            .map(|(key, value)| {
243                let sha = key
244                    .strip_prefix(CommitMetadataBySha::SHA_PREFIX)
245                    .unwrap_or(&key);
246                (CommitSha::new(sha.to_owned()), value)
247            })
248            .collect();
249        Ok(Self(map))
250    }
251}
252
/// A GitHub repository identified by its owner and name.
#[derive(Clone)]
pub struct Repository<'a> {
    owner: Cow<'a, str>,
    name: Cow<'a, str>,
}

impl<'a> Repository<'a> {
    /// The `zed-industries/zed` repository.
    pub const ZED: Repository<'static> = Repository::new_static("zed-industries", "zed");

    /// Creates a repository reference borrowing `owner` and `name`.
    pub fn new(owner: &'a str, name: &'a str) -> Self {
        let owner = Cow::Borrowed(owner);
        let name = Cow::Borrowed(name);
        Repository { owner, name }
    }

    /// The owning user or organization.
    pub fn owner(&self) -> &str {
        &*self.owner
    }

    /// The repository name.
    pub fn name(&self) -> &str {
        &*self.name
    }
}

impl Repository<'static> {
    /// `const` constructor for repositories known at compile time.
    pub const fn new_static(owner: &'static str, name: &'static str) -> Self {
        Repository {
            owner: Cow::Borrowed(owner),
            name: Cow::Borrowed(name),
        }
    }
}
286
287#[async_trait::async_trait(?Send)]
288pub trait GithubApiClient {
289    async fn get_pull_request(
290        &self,
291        repo: &Repository<'_>,
292        pr_number: u64,
293    ) -> Result<PullRequestData>;
294    async fn get_pull_request_reviews(
295        &self,
296        repo: &Repository<'_>,
297        pr_number: u64,
298    ) -> Result<Vec<PullRequestReview>>;
299    async fn get_pull_request_comments(
300        &self,
301        repo: &Repository<'_>,
302        pr_number: u64,
303    ) -> Result<Vec<PullRequestComment>>;
304    async fn get_commit_metadata(
305        &self,
306        repo: &Repository<'_>,
307        commit_shas: &[&CommitSha],
308    ) -> Result<CommitMetadataBySha>;
309    async fn get_commit_files(
310        &self,
311        repo: &Repository<'_>,
312        sha: &CommitSha,
313    ) -> Result<Vec<CommitFileChange>>;
314    async fn check_repo_write_permission(
315        &self,
316        repo: &Repository<'_>,
317        login: &GithubLogin,
318    ) -> Result<bool>;
319    async fn add_label_to_issue(
320        &self,
321        repo: &Repository<'_>,
322        label: &str,
323        issue_number: u64,
324    ) -> Result<()>;
325}
326
327pub mod graph_ql {
328    use anyhow::{Context as _, Result};
329    use itertools::Itertools as _;
330    use serde::Deserialize;
331
332    use crate::git::CommitSha;
333
334    use super::CommitMetadataBySha;
335
336    #[derive(Debug, Deserialize)]
337    pub struct GraphQLResponse<T> {
338        pub data: Option<T>,
339        pub errors: Option<Vec<GraphQLError>>,
340    }
341
342    impl<T> GraphQLResponse<T> {
343        pub fn into_data(self) -> Result<T> {
344            if let Some(errors) = &self.errors {
345                if !errors.is_empty() {
346                    let messages: String = errors.iter().map(|e| e.message.as_str()).join("; ");
347                    anyhow::bail!("GraphQL error: {messages}");
348                }
349            }
350
351            self.data.context("GraphQL response contained no data")
352        }
353    }
354
355    #[derive(Debug, Deserialize)]
356    pub struct GraphQLError {
357        pub message: String,
358    }
359
360    #[derive(Debug, Deserialize)]
361    pub struct CommitMetadataResponse {
362        pub repository: CommitMetadataBySha,
363    }
364
365    pub fn deserialize_nodes<'de, T, D>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
366    where
367        T: Deserialize<'de>,
368        D: serde::Deserializer<'de>,
369    {
370        #[derive(Deserialize)]
371        struct Nodes<T> {
372            nodes: Vec<T>,
373        }
374        Nodes::<T>::deserialize(deserializer).map(|wrapper| wrapper.nodes)
375    }
376
377    pub fn build_commit_metadata_query<'a>(
378        org: &str,
379        repo: &str,
380        shas: impl IntoIterator<Item = &'a CommitSha>,
381    ) -> String {
382        const FRAGMENT: &str = r#"
383            ... on Commit {
384                author {
385                    name
386                    email
387                    user { login }
388                }
389                authors(first: 10) {
390                    nodes {
391                        name
392                        email
393                        user { login }
394                    }
395                }
396                signature {
397                    isValid
398                    signer { login }
399                }
400                additions
401                deletions
402            }
403        "#;
404
405        let objects = shas
406            .into_iter()
407            .map(|commit_sha| {
408                format!(
409                    "{sha_prefix}{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
410                    sha_prefix = CommitMetadataBySha::SHA_PREFIX,
411                    sha = **commit_sha,
412                )
413            })
414            .join("\n");
415
416        format!("{{  repository(owner: \"{org}\", name: \"{repo}\") {{ {objects}  }} }}")
417            .replace("\n", "")
418    }
419}
420
421#[cfg(feature = "octo-client")]
422mod octo_client {
423    use anyhow::{Context, Result};
424    use futures::TryStreamExt as _;
425    use jsonwebtoken::EncodingKey;
426    use octocrab::{
427        Octocrab, Page,
428        models::{
429            AuthorAssociation as OctocrabAuthorAssociation,
430            pulls::ReviewState as OctocrabReviewState,
431        },
432        service::middleware::cache::mem::InMemoryCache,
433    };
434    use serde::de::DeserializeOwned;
435    use tokio::pin;
436
437    use crate::{
438        git::CommitSha,
439        github::{Repository, graph_ql},
440    };
441
442    use super::{
443        AuthorAssociation, CommitFileChange, CommitMetadataBySha, GithubApiClient, GithubLogin,
444        GithubUser, PullRequestComment, PullRequestData, PullRequestReview, ReviewState,
445    };
446
447    fn convert_author_association(association: OctocrabAuthorAssociation) -> AuthorAssociation {
448        match association {
449            OctocrabAuthorAssociation::Owner => AuthorAssociation::Owner,
450            OctocrabAuthorAssociation::Member => AuthorAssociation::Member,
451            OctocrabAuthorAssociation::Collaborator => AuthorAssociation::Collaborator,
452            OctocrabAuthorAssociation::Contributor => AuthorAssociation::Contributor,
453            OctocrabAuthorAssociation::FirstTimeContributor => {
454                AuthorAssociation::FirstTimeContributor
455            }
456            OctocrabAuthorAssociation::FirstTimer => AuthorAssociation::FirstTimer,
457            OctocrabAuthorAssociation::Mannequin => AuthorAssociation::Mannequin,
458            OctocrabAuthorAssociation::None => AuthorAssociation::None,
459            _ => AuthorAssociation::None,
460        }
461    }
462
463    const PAGE_SIZE: u8 = 100;
464
465    pub struct OctocrabClient {
466        client: Octocrab,
467    }
468
469    impl OctocrabClient {
470        pub async fn new(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
471            let octocrab = Octocrab::builder()
472                .cache(InMemoryCache::new())
473                .app(
474                    app_id.into(),
475                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
476                )
477                .build()?;
478
479            let installations = octocrab
480                .apps()
481                .installations()
482                .send()
483                .await
484                .context("Failed to fetch installations")?
485                .take_items();
486
487            let installation_id = installations
488                .into_iter()
489                .find(|installation| installation.account.login == org)
490                .context("Could not find Zed repository in installations")?
491                .id;
492
493            let client = octocrab.installation(installation_id)?;
494            Ok(Self { client })
495        }
496
497        async fn graphql<R: DeserializeOwned>(&self, query: &serde_json::Value) -> Result<R> {
498            let response: serde_json::Value = self.client.graphql(query).await?;
499            let parsed: graph_ql::GraphQLResponse<R> = serde_json::from_value(response)
500                .context("Failed to parse GraphQL response envelope")?;
501            parsed.into_data()
502        }
503
504        async fn get_all<T: DeserializeOwned + 'static>(
505            &self,
506            page: Page<T>,
507        ) -> octocrab::Result<Vec<T>> {
508            self.get_filtered(page, |_| true).await
509        }
510
511        async fn get_filtered<T: DeserializeOwned + 'static>(
512            &self,
513            page: Page<T>,
514            predicate: impl Fn(&T) -> bool,
515        ) -> octocrab::Result<Vec<T>> {
516            let stream = page.into_stream(&self.client);
517            pin!(stream);
518
519            let mut results = Vec::new();
520
521            while let Some(item) = stream.try_next().await?
522                && predicate(&item)
523            {
524                results.push(item);
525            }
526
527            Ok(results)
528        }
529    }
530
531    #[async_trait::async_trait(?Send)]
532    impl GithubApiClient for OctocrabClient {
533        async fn get_pull_request(
534            &self,
535            repo: &Repository<'_>,
536            pr_number: u64,
537        ) -> Result<PullRequestData> {
538            let pr = self
539                .client
540                .pulls(repo.owner.as_ref(), repo.name.as_ref())
541                .get(pr_number)
542                .await?;
543            Ok(PullRequestData {
544                number: pr.number,
545                user: pr.user.map(|user| GithubUser { login: user.login }),
546                merged_by: pr.merged_by.map(|user| GithubUser { login: user.login }),
547                labels: pr
548                    .labels
549                    .map(|labels| labels.into_iter().map(|label| label.name).collect()),
550            })
551        }
552
553        async fn get_pull_request_reviews(
554            &self,
555            repo: &Repository<'_>,
556            pr_number: u64,
557        ) -> Result<Vec<PullRequestReview>> {
558            let page = self
559                .client
560                .pulls(repo.owner.as_ref(), repo.name.as_ref())
561                .list_reviews(pr_number)
562                .per_page(PAGE_SIZE)
563                .send()
564                .await?;
565
566            let reviews = self.get_all(page).await?;
567
568            Ok(reviews
569                .into_iter()
570                .map(|review| PullRequestReview {
571                    user: review.user.map(|user| GithubUser { login: user.login }),
572                    state: review.state.map(|state| match state {
573                        OctocrabReviewState::Approved => ReviewState::Approved,
574                        _ => ReviewState::Other,
575                    }),
576                    body: review.body,
577                    author_association: review.author_association.map(convert_author_association),
578                })
579                .collect())
580        }
581
582        async fn get_pull_request_comments(
583            &self,
584            repo: &Repository<'_>,
585            pr_number: u64,
586        ) -> Result<Vec<PullRequestComment>> {
587            let page = self
588                .client
589                .issues(repo.owner.as_ref(), repo.name.as_ref())
590                .list_comments(pr_number)
591                .per_page(PAGE_SIZE)
592                .send()
593                .await?;
594
595            let comments = self.get_all(page).await?;
596
597            Ok(comments
598                .into_iter()
599                .map(|comment| PullRequestComment {
600                    user: GithubUser {
601                        login: comment.user.login,
602                    },
603                    body: comment.body,
604                    author_association: comment.author_association.map(convert_author_association),
605                })
606                .collect())
607        }
608
609        async fn get_commit_metadata(
610            &self,
611            repo: &Repository<'_>,
612            commit_shas: &[&CommitSha],
613        ) -> Result<CommitMetadataBySha> {
614            let query = graph_ql::build_commit_metadata_query(
615                repo.owner.as_ref(),
616                repo.name.as_ref(),
617                commit_shas.iter().copied(),
618            );
619            let query = serde_json::json!({ "query": query });
620            self.graphql::<graph_ql::CommitMetadataResponse>(&query)
621                .await
622                .map(|response| response.repository)
623        }
624
625        async fn get_commit_files(
626            &self,
627            repo: &Repository<'_>,
628            sha: &CommitSha,
629        ) -> Result<Vec<CommitFileChange>> {
630            let response = self
631                .client
632                .commits(repo.owner.as_ref(), repo.name.as_ref())
633                .get(sha.as_str())
634                .await?;
635
636            Ok(response
637                .files
638                .into_iter()
639                .flatten()
640                .map(|file| CommitFileChange {
641                    filename: file.filename,
642                })
643                .collect())
644        }
645
646        async fn check_repo_write_permission(
647            &self,
648            repo: &Repository<'_>,
649            login: &GithubLogin,
650        ) -> Result<bool> {
651            // Check org membership first - we save ourselves a few request that way
652            let page = self
653                .client
654                .orgs(repo.owner.as_ref())
655                .list_members()
656                .per_page(PAGE_SIZE)
657                .send()
658                .await?;
659
660            let members = self.get_all(page).await?;
661
662            if members
663                .into_iter()
664                .any(|member| member.login == login.as_str())
665            {
666                return Ok(true);
667            }
668
669            // TODO: octocrab fails to deserialize the permission response and
670            // does not adhere to the scheme laid out at
671            // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user
672
673            #[derive(serde::Deserialize)]
674            #[serde(rename_all = "lowercase")]
675            enum RepoPermission {
676                Admin,
677                Write,
678                Read,
679                #[serde(other)]
680                Other,
681            }
682
683            #[derive(serde::Deserialize)]
684            struct RepositoryPermissions {
685                permission: RepoPermission,
686            }
687
688            self.client
689                .get::<RepositoryPermissions, _, _>(
690                    format!(
691                        "/repos/{owner}/{repo}/collaborators/{user}/permission",
692                        owner = repo.owner.as_ref(),
693                        repo = repo.name.as_ref(),
694                        user = login.as_str()
695                    ),
696                    None::<&()>,
697                )
698                .await
699                .map(|response| {
700                    matches!(
701                        response.permission,
702                        RepoPermission::Write | RepoPermission::Admin
703                    )
704                })
705                .map_err(Into::into)
706        }
707
708        async fn add_label_to_issue(
709            &self,
710            repo: &Repository<'_>,
711            label: &str,
712            issue_number: u64,
713        ) -> Result<()> {
714            self.client
715                .issues(repo.owner.as_ref(), repo.name.as_ref())
716                .add_labels(issue_number, &[label.to_owned()])
717                .await
718                .map(|_| ())
719                .map_err(Into::into)
720        }
721    }
722}
723
724#[cfg(feature = "octo-client")]
725pub use octo_client::OctocrabClient;