github.rs

  1use std::{borrow::Cow, collections::HashMap, fmt, ops::Not, rc::Rc, sync::LazyLock};
  2
  3use anyhow::Result;
  4use derive_more::Deref;
  5use serde::Deserialize;
  6
  7use crate::git::CommitSha;
  8
  9pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
 10
 11#[derive(Debug, Clone)]
 12pub struct GithubUser {
 13    pub login: String,
 14}
 15
 16#[derive(Debug, Clone)]
 17pub struct PullRequestData {
 18    pub number: u64,
 19    pub user: Option<GithubUser>,
 20    pub merged_by: Option<GithubUser>,
 21    pub labels: Option<Vec<String>>,
 22}
 23
 24#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 25pub enum ReviewState {
 26    Approved,
 27    Other,
 28}
 29
 30#[derive(Debug, Clone)]
 31pub struct PullRequestReview {
 32    pub user: Option<GithubUser>,
 33    pub state: Option<ReviewState>,
 34    pub body: Option<String>,
 35}
 36
 37impl PullRequestReview {
 38    pub fn with_body(self, body: impl ToString) -> Self {
 39        Self {
 40            body: Some(body.to_string()),
 41            ..self
 42        }
 43    }
 44}
 45
 46#[derive(Debug, Clone)]
 47pub struct PullRequestComment {
 48    pub user: GithubUser,
 49    pub body: Option<String>,
 50}
 51
 52#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
 53pub struct GithubLogin {
 54    login: String,
 55}
 56
 57impl GithubLogin {
 58    pub fn new(login: String) -> Self {
 59        Self { login }
 60    }
 61}
 62
 63impl fmt::Display for GithubLogin {
 64    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
 65        write!(formatter, "@{}", self.login)
 66    }
 67}
 68
 69#[derive(Debug, Deserialize, Clone)]
 70pub struct CommitAuthor {
 71    name: String,
 72    email: String,
 73    user: Option<GithubLogin>,
 74}
 75
 76pub(crate) static ZED_ZIPPY_AUTHOR: LazyLock<CommitAuthor> = LazyLock::new(|| CommitAuthor {
 77    name: "Zed Zippy".to_string(),
 78    email: "234243425+zed-zippy[bot]@users.noreply.github.com".to_string(),
 79    user: Some(GithubLogin {
 80        login: "zed-zippy[bot]".to_string(),
 81    }),
 82});
 83
 84impl CommitAuthor {
 85    pub(crate) fn user(&self) -> Option<&GithubLogin> {
 86        self.user.as_ref()
 87    }
 88}
 89
 90impl PartialEq for CommitAuthor {
 91    fn eq(&self, other: &Self) -> bool {
 92        self.user.as_ref().zip(other.user.as_ref()).map_or_else(
 93            || self.email == other.email || self.name == other.name,
 94            |(l, r)| l == r,
 95        )
 96    }
 97}
 98
 99impl fmt::Display for CommitAuthor {
100    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
101        match self.user.as_ref() {
102            Some(user) => write!(formatter, "{} ({user})", self.name),
103            None => write!(formatter, "{} ({})", self.name, self.email),
104        }
105    }
106}
107
108#[derive(Debug, Deserialize)]
109pub struct CommitAuthors {
110    #[serde(rename = "author")]
111    primary_author: CommitAuthor,
112    #[serde(rename = "authors", deserialize_with = "graph_ql::deserialize_nodes")]
113    co_authors: Vec<CommitAuthor>,
114}
115
116impl CommitAuthors {
117    pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
118        self.co_authors.is_empty().not().then(|| {
119            self.co_authors
120                .iter()
121                .filter(|co_author| *co_author != &self.primary_author)
122        })
123    }
124}
125
126#[derive(Debug, Deref)]
127pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
128
129impl AuthorsForCommits {
130    const SHA_PREFIX: &'static str = "commit";
131}
132
133impl<'de> serde::Deserialize<'de> for AuthorsForCommits {
134    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
135    where
136        D: serde::Deserializer<'de>,
137    {
138        let raw = HashMap::<String, CommitAuthors>::deserialize(deserializer)?;
139        let map = raw
140            .into_iter()
141            .map(|(key, value)| {
142                let sha = key
143                    .strip_prefix(AuthorsForCommits::SHA_PREFIX)
144                    .unwrap_or(&key);
145                (CommitSha::new(sha.to_owned()), value)
146            })
147            .collect();
148        Ok(Self(map))
149    }
150}
151
/// A GitHub repository, identified by its owner and name (`owner/name`).
#[derive(Clone)]
pub struct Repository<'a> {
    owner: Cow<'a, str>,
    name: Cow<'a, str>,
}

impl<'a> Repository<'a> {
    /// The `zed-industries/zed` repository.
    pub const ZED: Repository<'static> = Repository::new_static("zed-industries", "zed");

    /// Creates a repository reference that borrows `owner` and `name`.
    pub fn new(owner: &'a str, name: &'a str) -> Self {
        Repository {
            owner: owner.into(),
            name: name.into(),
        }
    }

    /// The owner segment of `owner/name`.
    pub fn owner(&self) -> &str {
        &*self.owner
    }

    /// The name segment of `owner/name`.
    pub fn name(&self) -> &str {
        &*self.name
    }
}

impl Repository<'static> {
    /// `const` constructor for repositories named by `'static` strings, usable in `const` items
    /// such as `Repository::ZED`.
    pub const fn new_static(owner: &'static str, name: &'static str) -> Self {
        Repository {
            owner: Cow::Borrowed(owner),
            name: Cow::Borrowed(name),
        }
    }
}
185
186#[async_trait::async_trait(?Send)]
187pub trait GithubApiClient {
188    async fn get_pull_request(
189        &self,
190        repo: &Repository<'_>,
191        pr_number: u64,
192    ) -> Result<PullRequestData>;
193    async fn get_pull_request_reviews(
194        &self,
195        repo: &Repository<'_>,
196        pr_number: u64,
197    ) -> Result<Vec<PullRequestReview>>;
198    async fn get_pull_request_comments(
199        &self,
200        repo: &Repository<'_>,
201        pr_number: u64,
202    ) -> Result<Vec<PullRequestComment>>;
203    async fn get_commit_authors(
204        &self,
205        repo: &Repository<'_>,
206        commit_shas: &[&CommitSha],
207    ) -> Result<AuthorsForCommits>;
208    async fn check_repo_write_permission(
209        &self,
210        repo: &Repository<'_>,
211        login: &GithubLogin,
212    ) -> Result<bool>;
213    async fn add_label_to_issue(
214        &self,
215        repo: &Repository<'_>,
216        label: &str,
217        issue_number: u64,
218    ) -> Result<()>;
219}
220
221#[derive(Deref)]
222pub struct GithubClient {
223    api: Rc<dyn GithubApiClient>,
224}
225
226impl GithubClient {
227    pub fn new(api: Rc<dyn GithubApiClient>) -> Self {
228        Self { api }
229    }
230
231    #[cfg(feature = "octo-client")]
232    pub async fn for_app_in_repo(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
233        let client = OctocrabClient::new(app_id, app_private_key, org).await?;
234        Ok(Self::new(Rc::new(client)))
235    }
236}
237
238pub mod graph_ql {
239    use anyhow::{Context as _, Result};
240    use itertools::Itertools as _;
241    use serde::Deserialize;
242
243    use crate::git::CommitSha;
244
245    use super::AuthorsForCommits;
246
247    #[derive(Debug, Deserialize)]
248    pub struct GraphQLResponse<T> {
249        pub data: Option<T>,
250        pub errors: Option<Vec<GraphQLError>>,
251    }
252
253    impl<T> GraphQLResponse<T> {
254        pub fn into_data(self) -> Result<T> {
255            if let Some(errors) = &self.errors {
256                if !errors.is_empty() {
257                    let messages: String = errors.iter().map(|e| e.message.as_str()).join("; ");
258                    anyhow::bail!("GraphQL error: {messages}");
259                }
260            }
261
262            self.data.context("GraphQL response contained no data")
263        }
264    }
265
266    #[derive(Debug, Deserialize)]
267    pub struct GraphQLError {
268        pub message: String,
269    }
270
271    #[derive(Debug, Deserialize)]
272    pub struct CommitAuthorsResponse {
273        pub repository: AuthorsForCommits,
274    }
275
276    pub fn deserialize_nodes<'de, T, D>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
277    where
278        T: Deserialize<'de>,
279        D: serde::Deserializer<'de>,
280    {
281        #[derive(Deserialize)]
282        struct Nodes<T> {
283            nodes: Vec<T>,
284        }
285        Nodes::<T>::deserialize(deserializer).map(|wrapper| wrapper.nodes)
286    }
287
288    pub fn build_co_authors_query<'a>(
289        org: &str,
290        repo: &str,
291        shas: impl IntoIterator<Item = &'a CommitSha>,
292    ) -> String {
293        const FRAGMENT: &str = r#"
294            ... on Commit {
295                author {
296                    name
297                    email
298                    user { login }
299                }
300                authors(first: 10) {
301                    nodes {
302                        name
303                        email
304                        user { login }
305                    }
306                }
307            }
308        "#;
309
310        let objects = shas
311            .into_iter()
312            .map(|commit_sha| {
313                format!(
314                    "{sha_prefix}{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
315                    sha_prefix = AuthorsForCommits::SHA_PREFIX,
316                    sha = **commit_sha,
317                )
318            })
319            .join("\n");
320
321        format!("{{  repository(owner: \"{org}\", name: \"{repo}\") {{ {objects}  }} }}")
322            .replace("\n", "")
323    }
324}
325
#[cfg(feature = "octo-client")]
mod octo_client {
    //! [`GithubApiClient`] implementation backed by `octocrab`, authenticated as a GitHub App
    //! installation.

    use std::collections::HashMap;

    use anyhow::{Context, Result};
    use futures::TryStreamExt as _;
    use jsonwebtoken::EncodingKey;
    use octocrab::{
        Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
        service::middleware::cache::mem::InMemoryCache,
    };
    use serde::de::DeserializeOwned;
    use tokio::pin;

    use crate::{
        git::CommitSha,
        github::{Repository, graph_ql},
    };

    use super::{
        AuthorsForCommits, GithubApiClient, GithubLogin, GithubUser, PullRequestComment,
        PullRequestData, PullRequestReview, ReviewState,
    };

    /// Page size requested from paginated REST endpoints.
    const PAGE_SIZE: u8 = 100;

    /// GitHub API client authenticated as a single GitHub App installation.
    pub struct OctocrabClient {
        client: Octocrab,
    }

    impl OctocrabClient {
        /// Authenticates as GitHub App `app_id` using the PEM-encoded RSA key
        /// `app_private_key`. Returns a client scoped to the app's installation on the account
        /// whose login is `org`.
        ///
        /// # Errors
        /// Fails if the key is not a valid RSA PEM, if the client cannot be built, if listing
        /// installations fails, or if no installation on `org` is found.
        pub async fn new(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
            let octocrab = Octocrab::builder()
                .cache(InMemoryCache::new())
                .app(
                    app_id.into(),
                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
                )
                .build()?;

            // NOTE(review): only the first page of installations is inspected; an app installed
            // on many accounts could miss `org` here.
            let installations = octocrab
                .apps()
                .installations()
                .send()
                .await
                .context("Failed to fetch installations")?
                .take_items();

            let installation_id = installations
                .into_iter()
                .find(|installation| installation.account.login == org)
                .with_context(|| format!("Could not find an app installation for `{org}`"))?
                .id;

            let client = octocrab.installation(installation_id)?;
            Ok(Self { client })
        }

        /// Sends the GraphQL request body `query` and returns its `data` payload as `R`.
        ///
        /// # Errors
        /// Fails on transport errors, if the response envelope cannot be parsed, if it carries
        /// GraphQL errors, or if it has no `data`.
        async fn graphql<R: DeserializeOwned>(&self, query: &serde_json::Value) -> Result<R> {
            let response: serde_json::Value = self.client.graphql(query).await?;
            let parsed: graph_ql::GraphQLResponse<R> = serde_json::from_value(response)
                .context("Failed to parse GraphQL response envelope")?;
            parsed.into_data()
        }

        /// Collects every item of `page` and of the pages that follow it.
        async fn get_all<T: DeserializeOwned + 'static>(
            &self,
            page: Page<T>,
        ) -> octocrab::Result<Vec<T>> {
            self.get_filtered(page, |_| true).await
        }

        /// Collects every item of `page` and the pages that follow it for which `predicate`
        /// returns `true`.
        async fn get_filtered<T: DeserializeOwned + 'static>(
            &self,
            page: Page<T>,
            predicate: impl Fn(&T) -> bool,
        ) -> octocrab::Result<Vec<T>> {
            let stream = page.into_stream(&self.client);
            pin!(stream);

            let mut results = Vec::new();
            // Skip (rather than stop at) items that fail the predicate, so this filters instead
            // of behaving like `take_while`.
            while let Some(item) = stream.try_next().await? {
                if predicate(&item) {
                    results.push(item);
                }
            }

            Ok(results)
        }
    }

    #[async_trait::async_trait(?Send)]
    impl GithubApiClient for OctocrabClient {
        async fn get_pull_request(
            &self,
            repo: &Repository<'_>,
            pr_number: u64,
        ) -> Result<PullRequestData> {
            let pr = self
                .client
                .pulls(repo.owner.as_ref(), repo.name.as_ref())
                .get(pr_number)
                .await?;
            Ok(PullRequestData {
                number: pr.number,
                user: pr.user.map(|user| GithubUser { login: user.login }),
                merged_by: pr.merged_by.map(|user| GithubUser { login: user.login }),
                labels: pr
                    .labels
                    .map(|labels| labels.into_iter().map(|label| label.name).collect()),
            })
        }

        async fn get_pull_request_reviews(
            &self,
            repo: &Repository<'_>,
            pr_number: u64,
        ) -> Result<Vec<PullRequestReview>> {
            let page = self
                .client
                .pulls(repo.owner.as_ref(), repo.name.as_ref())
                .list_reviews(pr_number)
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let reviews = self.get_all(page).await?;

            Ok(reviews
                .into_iter()
                .map(|review| PullRequestReview {
                    user: review.user.map(|user| GithubUser { login: user.login }),
                    state: review.state.map(|state| match state {
                        OctocrabReviewState::Approved => ReviewState::Approved,
                        _ => ReviewState::Other,
                    }),
                    body: review.body,
                })
                .collect())
        }

        /// Pull request conversation comments are read through the issues API.
        async fn get_pull_request_comments(
            &self,
            repo: &Repository<'_>,
            pr_number: u64,
        ) -> Result<Vec<PullRequestComment>> {
            let page = self
                .client
                .issues(repo.owner.as_ref(), repo.name.as_ref())
                .list_comments(pr_number)
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let comments = self.get_all(page).await?;

            Ok(comments
                .into_iter()
                .map(|comment| PullRequestComment {
                    user: GithubUser {
                        login: comment.user.login,
                    },
                    body: comment.body,
                })
                .collect())
        }

        async fn get_commit_authors(
            &self,
            repo: &Repository<'_>,
            commit_shas: &[&CommitSha],
        ) -> Result<AuthorsForCommits> {
            // With no SHAs the query's `repository` selection set would be empty, which is not
            // valid GraphQL; answer locally instead of sending a request that cannot succeed.
            if commit_shas.is_empty() {
                return Ok(AuthorsForCommits(HashMap::new()));
            }

            let query = graph_ql::build_co_authors_query(
                repo.owner.as_ref(),
                repo.name.as_ref(),
                commit_shas.iter().copied(),
            );
            let query = serde_json::json!({ "query": query });
            self.graphql::<graph_ql::CommitAuthorsResponse>(&query)
                .await
                .map(|response| response.repository)
        }

        /// Returns `true` if `login` is listed as a member of the `repo.owner` organization.
        /// Otherwise returns `true` when the user's repository permission is `write` or `admin`.
        async fn check_repo_write_permission(
            &self,
            repo: &Repository<'_>,
            login: &GithubLogin,
        ) -> Result<bool> {
            // Check org membership first - we save ourselves a few requests that way
            let page = self
                .client
                .orgs(repo.owner.as_ref())
                .list_members()
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let members = self.get_all(page).await?;

            if members
                .into_iter()
                .any(|member| member.login == login.as_str())
            {
                return Ok(true);
            }

            // TODO: octocrab fails to deserialize the permission response and
            // does not adhere to the scheme laid out at
            // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user

            #[derive(serde::Deserialize)]
            #[serde(rename_all = "lowercase")]
            enum RepoPermission {
                Admin,
                Write,
                Read,
                #[serde(other)]
                Other,
            }

            #[derive(serde::Deserialize)]
            struct RepositoryPermissions {
                permission: RepoPermission,
            }

            self.client
                .get::<RepositoryPermissions, _, _>(
                    format!(
                        "/repos/{owner}/{repo}/collaborators/{user}/permission",
                        owner = repo.owner.as_ref(),
                        repo = repo.name.as_ref(),
                        user = login.as_str()
                    ),
                    None::<&()>,
                )
                .await
                .map(|response| {
                    matches!(
                        response.permission,
                        RepoPermission::Write | RepoPermission::Admin
                    )
                })
                .map_err(Into::into)
        }

        async fn add_label_to_issue(
            &self,
            repo: &Repository<'_>,
            label: &str,
            issue_number: u64,
        ) -> Result<()> {
            self.client
                .issues(repo.owner.as_ref(), repo.name.as_ref())
                .add_labels(issue_number, &[label.to_owned()])
                .await
                .map(|_| ())
                .map_err(Into::into)
        }
    }
}

#[cfg(feature = "octo-client")]
pub use octo_client::OctocrabClient;