github.rs

  1use std::{borrow::Cow, collections::HashMap, fmt};
  2
  3use anyhow::Result;
  4use derive_more::Deref;
  5use serde::Deserialize;
  6
  7use crate::git::CommitSha;
  8
/// Name of the `PR state:needs review` GitHub label.
///
/// NOTE(review): presumably handed to `GithubApiClient::add_label_to_issue` to flag pull
/// requests that still need a review — confirm against the callers.
pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
 10
 11#[derive(Debug, Clone)]
 12pub struct GithubUser {
 13    pub login: String,
 14}
 15
 16#[derive(Debug, Clone)]
 17pub struct PullRequestData {
 18    pub number: u64,
 19    pub user: Option<GithubUser>,
 20    pub merged_by: Option<GithubUser>,
 21    pub labels: Option<Vec<String>>,
 22}
 23
 24#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 25pub enum ReviewState {
 26    Approved,
 27    Other,
 28}
 29
 30#[derive(Debug, Clone)]
 31pub struct PullRequestReview {
 32    pub user: Option<GithubUser>,
 33    pub state: Option<ReviewState>,
 34    pub body: Option<String>,
 35}
 36
 37impl PullRequestReview {
 38    pub fn with_body(self, body: impl ToString) -> Self {
 39        Self {
 40            body: Some(body.to_string()),
 41            ..self
 42        }
 43    }
 44}
 45
 46#[derive(Debug, Clone)]
 47pub struct PullRequestComment {
 48    pub user: GithubUser,
 49    pub body: Option<String>,
 50}
 51
 52#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
 53pub struct GithubLogin {
 54    login: String,
 55}
 56
 57impl GithubLogin {
 58    pub fn new(login: String) -> Self {
 59        Self { login }
 60    }
 61}
 62
 63impl fmt::Display for GithubLogin {
 64    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
 65        write!(formatter, "@{}", self.login)
 66    }
 67}
 68
 69#[derive(Debug, Deserialize, Clone)]
 70pub struct CommitAuthor {
 71    name: String,
 72    email: String,
 73    user: Option<GithubLogin>,
 74}
 75
 76impl CommitAuthor {
 77    pub(crate) fn user(&self) -> Option<&GithubLogin> {
 78        self.user.as_ref()
 79    }
 80}
 81
 82impl PartialEq for CommitAuthor {
 83    fn eq(&self, other: &Self) -> bool {
 84        self.user.as_ref().zip(other.user.as_ref()).map_or_else(
 85            || self.email == other.email || self.name == other.name,
 86            |(l, r)| l == r,
 87        )
 88    }
 89}
 90
 91impl fmt::Display for CommitAuthor {
 92    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
 93        match self.user.as_ref() {
 94            Some(user) => write!(formatter, "{} ({user})", self.name),
 95            None => write!(formatter, "{} ({})", self.name, self.email),
 96        }
 97    }
 98}
 99
100#[derive(Debug, Deserialize, Clone)]
101pub struct CommitSignature {
102    #[serde(rename = "isValid")]
103    is_valid: bool,
104    signer: Option<GithubLogin>,
105}
106
107impl CommitSignature {
108    pub fn is_valid(&self) -> bool {
109        self.is_valid
110    }
111
112    pub fn signer(&self) -> Option<&GithubLogin> {
113        self.signer.as_ref()
114    }
115}
116
117#[derive(Debug, Deserialize)]
118pub struct CommitMetadata {
119    #[serde(rename = "author")]
120    primary_author: CommitAuthor,
121    #[serde(rename = "authors", deserialize_with = "graph_ql::deserialize_nodes")]
122    co_authors: Vec<CommitAuthor>,
123    #[serde(default)]
124    signature: Option<CommitSignature>,
125    #[serde(default)]
126    additions: u64,
127    #[serde(default)]
128    deletions: u64,
129}
130
131impl CommitMetadata {
132    pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
133        let mut co_authors = self
134            .co_authors
135            .iter()
136            .filter(|co_author| *co_author != &self.primary_author)
137            .peekable();
138
139        co_authors.peek().is_some().then_some(co_authors)
140    }
141
142    pub fn primary_author(&self) -> &CommitAuthor {
143        &self.primary_author
144    }
145
146    pub fn signature(&self) -> Option<&CommitSignature> {
147        self.signature.as_ref()
148    }
149
150    pub fn additions(&self) -> u64 {
151        self.additions
152    }
153
154    pub fn deletions(&self) -> u64 {
155        self.deletions
156    }
157}
158
159#[derive(Debug, Deref)]
160pub struct CommitMetadataBySha(HashMap<CommitSha, CommitMetadata>);
161
162impl CommitMetadataBySha {
163    const SHA_PREFIX: &'static str = "commit";
164}
165
166impl<'de> serde::Deserialize<'de> for CommitMetadataBySha {
167    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
168    where
169        D: serde::Deserializer<'de>,
170    {
171        let raw = HashMap::<String, CommitMetadata>::deserialize(deserializer)?;
172        let map = raw
173            .into_iter()
174            .map(|(key, value)| {
175                let sha = key
176                    .strip_prefix(CommitMetadataBySha::SHA_PREFIX)
177                    .unwrap_or(&key);
178                (CommitSha::new(sha.to_owned()), value)
179            })
180            .collect();
181        Ok(Self(map))
182    }
183}
184
/// A GitHub repository, identified by its owner and its name.
#[derive(Clone)]
pub struct Repository<'a> {
    owner: Cow<'a, str>,
    name: Cow<'a, str>,
}

impl<'a> Repository<'a> {
    /// The `zed-industries/zed` repository.
    pub const ZED: Repository<'static> = Repository {
        owner: Cow::Borrowed("zed-industries"),
        name: Cow::Borrowed("zed"),
    };

    /// Creates a repository that borrows `owner` and `name`.
    pub fn new(owner: &'a str, name: &'a str) -> Self {
        Repository {
            owner: owner.into(),
            name: name.into(),
        }
    }

    /// Returns the owner (user or organization) of the repository.
    pub fn owner(&self) -> &str {
        self.owner.as_ref()
    }

    /// Returns the name of the repository.
    pub fn name(&self) -> &str {
        self.name.as_ref()
    }
}

impl Repository<'static> {
    /// `const` counterpart of [`Repository::new`] for `'static` strings.
    pub const fn new_static(owner: &'static str, name: &'static str) -> Self {
        let owner = Cow::Borrowed(owner);
        let name = Cow::Borrowed(name);
        Self { owner, name }
    }
}
218
/// The GitHub operations this crate relies on.
///
/// Declared with `async_trait(?Send)`, so the futures returned by the methods are not
/// required to be `Send`.
#[async_trait::async_trait(?Send)]
pub trait GithubApiClient {
    /// Fetches the pull request `pr_number` of `repo`.
    async fn get_pull_request(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<PullRequestData>;
    /// Fetches the reviews of pull request `pr_number` of `repo`.
    async fn get_pull_request_reviews(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<Vec<PullRequestReview>>;
    /// Fetches the comments of pull request `pr_number` of `repo`.
    async fn get_pull_request_comments(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<Vec<PullRequestComment>>;
    /// Fetches the metadata of the commits `commit_shas` of `repo`, keyed by SHA.
    async fn get_commit_metadata(
        &self,
        repo: &Repository<'_>,
        commit_shas: &[&CommitSha],
    ) -> Result<CommitMetadataBySha>;
    /// Reports whether `login` is allowed to write to `repo`.
    async fn check_repo_write_permission(
        &self,
        repo: &Repository<'_>,
        login: &GithubLogin,
    ) -> Result<bool>;
    /// Adds `label` to the issue `issue_number` of `repo`.
    async fn add_label_to_issue(
        &self,
        repo: &Repository<'_>,
        label: &str,
        issue_number: u64,
    ) -> Result<()>;
}
253
254pub mod graph_ql {
255    use anyhow::{Context as _, Result};
256    use itertools::Itertools as _;
257    use serde::Deserialize;
258
259    use crate::git::CommitSha;
260
261    use super::CommitMetadataBySha;
262
263    #[derive(Debug, Deserialize)]
264    pub struct GraphQLResponse<T> {
265        pub data: Option<T>,
266        pub errors: Option<Vec<GraphQLError>>,
267    }
268
269    impl<T> GraphQLResponse<T> {
270        pub fn into_data(self) -> Result<T> {
271            if let Some(errors) = &self.errors {
272                if !errors.is_empty() {
273                    let messages: String = errors.iter().map(|e| e.message.as_str()).join("; ");
274                    anyhow::bail!("GraphQL error: {messages}");
275                }
276            }
277
278            self.data.context("GraphQL response contained no data")
279        }
280    }
281
282    #[derive(Debug, Deserialize)]
283    pub struct GraphQLError {
284        pub message: String,
285    }
286
287    #[derive(Debug, Deserialize)]
288    pub struct CommitMetadataResponse {
289        pub repository: CommitMetadataBySha,
290    }
291
292    pub fn deserialize_nodes<'de, T, D>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
293    where
294        T: Deserialize<'de>,
295        D: serde::Deserializer<'de>,
296    {
297        #[derive(Deserialize)]
298        struct Nodes<T> {
299            nodes: Vec<T>,
300        }
301        Nodes::<T>::deserialize(deserializer).map(|wrapper| wrapper.nodes)
302    }
303
304    pub fn build_commit_metadata_query<'a>(
305        org: &str,
306        repo: &str,
307        shas: impl IntoIterator<Item = &'a CommitSha>,
308    ) -> String {
309        const FRAGMENT: &str = r#"
310            ... on Commit {
311                author {
312                    name
313                    email
314                    user { login }
315                }
316                authors(first: 10) {
317                    nodes {
318                        name
319                        email
320                        user { login }
321                    }
322                }
323                signature {
324                    isValid
325                    signer { login }
326                }
327                additions
328                deletions
329            }
330        "#;
331
332        let objects = shas
333            .into_iter()
334            .map(|commit_sha| {
335                format!(
336                    "{sha_prefix}{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
337                    sha_prefix = CommitMetadataBySha::SHA_PREFIX,
338                    sha = **commit_sha,
339                )
340            })
341            .join("\n");
342
343        format!("{{  repository(owner: \"{org}\", name: \"{repo}\") {{ {objects}  }} }}")
344            .replace("\n", "")
345    }
346}
347
348#[cfg(feature = "octo-client")]
349mod octo_client {
350    use anyhow::{Context, Result};
351    use futures::TryStreamExt as _;
352    use jsonwebtoken::EncodingKey;
353    use octocrab::{
354        Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
355        service::middleware::cache::mem::InMemoryCache,
356    };
357    use serde::de::DeserializeOwned;
358    use tokio::pin;
359
360    use crate::{
361        git::CommitSha,
362        github::{Repository, graph_ql},
363    };
364
365    use super::{
366        CommitMetadataBySha, GithubApiClient, GithubLogin, GithubUser, PullRequestComment,
367        PullRequestData, PullRequestReview, ReviewState,
368    };
369
    /// Number of items requested per page by the paginated requests in this module.
    const PAGE_SIZE: u8 = 100;
371
    /// `GithubApiClient` implementation backed by an `Octocrab` client.
    pub struct OctocrabClient {
        // The installation-scoped client built by `OctocrabClient::new`.
        client: Octocrab,
    }
375
376    impl OctocrabClient {
377        pub async fn new(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
378            let octocrab = Octocrab::builder()
379                .cache(InMemoryCache::new())
380                .app(
381                    app_id.into(),
382                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
383                )
384                .build()?;
385
386            let installations = octocrab
387                .apps()
388                .installations()
389                .send()
390                .await
391                .context("Failed to fetch installations")?
392                .take_items();
393
394            let installation_id = installations
395                .into_iter()
396                .find(|installation| installation.account.login == org)
397                .context("Could not find Zed repository in installations")?
398                .id;
399
400            let client = octocrab.installation(installation_id)?;
401            Ok(Self { client })
402        }
403
404        async fn graphql<R: DeserializeOwned>(&self, query: &serde_json::Value) -> Result<R> {
405            let response: serde_json::Value = self.client.graphql(query).await?;
406            let parsed: graph_ql::GraphQLResponse<R> = serde_json::from_value(response)
407                .context("Failed to parse GraphQL response envelope")?;
408            parsed.into_data()
409        }
410
411        async fn get_all<T: DeserializeOwned + 'static>(
412            &self,
413            page: Page<T>,
414        ) -> octocrab::Result<Vec<T>> {
415            self.get_filtered(page, |_| true).await
416        }
417
418        async fn get_filtered<T: DeserializeOwned + 'static>(
419            &self,
420            page: Page<T>,
421            predicate: impl Fn(&T) -> bool,
422        ) -> octocrab::Result<Vec<T>> {
423            let stream = page.into_stream(&self.client);
424            pin!(stream);
425
426            let mut results = Vec::new();
427
428            while let Some(item) = stream.try_next().await?
429                && predicate(&item)
430            {
431                results.push(item);
432            }
433
434            Ok(results)
435        }
436    }
437
438    #[async_trait::async_trait(?Send)]
439    impl GithubApiClient for OctocrabClient {
440        async fn get_pull_request(
441            &self,
442            repo: &Repository<'_>,
443            pr_number: u64,
444        ) -> Result<PullRequestData> {
445            let pr = self
446                .client
447                .pulls(repo.owner.as_ref(), repo.name.as_ref())
448                .get(pr_number)
449                .await?;
450            Ok(PullRequestData {
451                number: pr.number,
452                user: pr.user.map(|user| GithubUser { login: user.login }),
453                merged_by: pr.merged_by.map(|user| GithubUser { login: user.login }),
454                labels: pr
455                    .labels
456                    .map(|labels| labels.into_iter().map(|label| label.name).collect()),
457            })
458        }
459
460        async fn get_pull_request_reviews(
461            &self,
462            repo: &Repository<'_>,
463            pr_number: u64,
464        ) -> Result<Vec<PullRequestReview>> {
465            let page = self
466                .client
467                .pulls(repo.owner.as_ref(), repo.name.as_ref())
468                .list_reviews(pr_number)
469                .per_page(PAGE_SIZE)
470                .send()
471                .await?;
472
473            let reviews = self.get_all(page).await?;
474
475            Ok(reviews
476                .into_iter()
477                .map(|review| PullRequestReview {
478                    user: review.user.map(|user| GithubUser { login: user.login }),
479                    state: review.state.map(|state| match state {
480                        OctocrabReviewState::Approved => ReviewState::Approved,
481                        _ => ReviewState::Other,
482                    }),
483                    body: review.body,
484                })
485                .collect())
486        }
487
488        async fn get_pull_request_comments(
489            &self,
490            repo: &Repository<'_>,
491            pr_number: u64,
492        ) -> Result<Vec<PullRequestComment>> {
493            let page = self
494                .client
495                .issues(repo.owner.as_ref(), repo.name.as_ref())
496                .list_comments(pr_number)
497                .per_page(PAGE_SIZE)
498                .send()
499                .await?;
500
501            let comments = self.get_all(page).await?;
502
503            Ok(comments
504                .into_iter()
505                .map(|comment| PullRequestComment {
506                    user: GithubUser {
507                        login: comment.user.login,
508                    },
509                    body: comment.body,
510                })
511                .collect())
512        }
513
514        async fn get_commit_metadata(
515            &self,
516            repo: &Repository<'_>,
517            commit_shas: &[&CommitSha],
518        ) -> Result<CommitMetadataBySha> {
519            let query = graph_ql::build_commit_metadata_query(
520                repo.owner.as_ref(),
521                repo.name.as_ref(),
522                commit_shas.iter().copied(),
523            );
524            let query = serde_json::json!({ "query": query });
525            self.graphql::<graph_ql::CommitMetadataResponse>(&query)
526                .await
527                .map(|response| response.repository)
528        }
529
530        async fn check_repo_write_permission(
531            &self,
532            repo: &Repository<'_>,
533            login: &GithubLogin,
534        ) -> Result<bool> {
535            // Check org membership first - we save ourselves a few request that way
536            let page = self
537                .client
538                .orgs(repo.owner.as_ref())
539                .list_members()
540                .per_page(PAGE_SIZE)
541                .send()
542                .await?;
543
544            let members = self.get_all(page).await?;
545
546            if members
547                .into_iter()
548                .any(|member| member.login == login.as_str())
549            {
550                return Ok(true);
551            }
552
553            // TODO: octocrab fails to deserialize the permission response and
554            // does not adhere to the scheme laid out at
555            // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user
556
557            #[derive(serde::Deserialize)]
558            #[serde(rename_all = "lowercase")]
559            enum RepoPermission {
560                Admin,
561                Write,
562                Read,
563                #[serde(other)]
564                Other,
565            }
566
567            #[derive(serde::Deserialize)]
568            struct RepositoryPermissions {
569                permission: RepoPermission,
570            }
571
572            self.client
573                .get::<RepositoryPermissions, _, _>(
574                    format!(
575                        "/repos/{owner}/{repo}/collaborators/{user}/permission",
576                        owner = repo.owner.as_ref(),
577                        repo = repo.name.as_ref(),
578                        user = login.as_str()
579                    ),
580                    None::<&()>,
581                )
582                .await
583                .map(|response| {
584                    matches!(
585                        response.permission,
586                        RepoPermission::Write | RepoPermission::Admin
587                    )
588                })
589                .map_err(Into::into)
590        }
591
592        async fn add_label_to_issue(
593            &self,
594            repo: &Repository<'_>,
595            label: &str,
596            issue_number: u64,
597        ) -> Result<()> {
598            self.client
599                .issues(repo.owner.as_ref(), repo.name.as_ref())
600                .add_labels(issue_number, &[label.to_owned()])
601                .await
602                .map(|_| ())
603                .map_err(Into::into)
604        }
605    }
606}
607
608#[cfg(feature = "octo-client")]
609pub use octo_client::OctocrabClient;