github.rs

  1use std::{borrow::Cow, collections::HashMap, fmt, ops::Not, sync::LazyLock};
  2
  3use anyhow::Result;
  4use derive_more::Deref;
  5use serde::Deserialize;
  6
  7use crate::git::CommitSha;
  8
/// Name of the GitHub label that marks a pull request as needing review.
pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
 10
/// A GitHub account, identified only by its login name.
#[derive(Debug, Clone)]
pub struct GithubUser {
    /// The account's login (username).
    pub login: String,
}
 15
/// The subset of pull request metadata used by this crate.
#[derive(Debug, Clone)]
pub struct PullRequestData {
    /// Pull request number within its repository.
    pub number: u64,
    /// Author of the pull request, if the API reported one.
    pub user: Option<GithubUser>,
    /// User who merged the pull request, if the API reported one.
    pub merged_by: Option<GithubUser>,
    /// Names of the labels attached to the pull request, if the API reported them.
    pub labels: Option<Vec<String>>,
}
 23
/// Simplified pull request review state: only approvals are distinguished.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReviewState {
    /// The review approved the pull request.
    Approved,
    /// Any non-approving review state.
    Other,
}
 29
/// A single review submitted on a pull request.
#[derive(Debug, Clone)]
pub struct PullRequestReview {
    /// The reviewer, if the API reported one.
    pub user: Option<GithubUser>,
    /// Simplified review state, if the API reported one.
    pub state: Option<ReviewState>,
    /// Review body text, if any.
    pub body: Option<String>,
}
 36
 37impl PullRequestReview {
 38    pub fn with_body(self, body: impl ToString) -> Self {
 39        Self {
 40            body: Some(body.to_string()),
 41            ..self
 42        }
 43    }
 44}
 45
/// A comment on a pull request, as returned by
/// [`GithubApiClient::get_pull_request_comments`].
#[derive(Debug, Clone)]
pub struct PullRequestComment {
    /// Author of the comment.
    pub user: GithubUser,
    /// Comment text, if any.
    pub body: Option<String>,
}
 51
/// A GitHub login name.
///
/// Dereferences to the underlying `String`. Its [`Display`](fmt::Display) form is
/// prefixed with `@`, mention-style.
#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
pub struct GithubLogin {
    /// The raw login; deserialized from an object's `login` key.
    login: String,
}
 56
 57impl GithubLogin {
 58    pub fn new(login: String) -> Self {
 59        Self { login }
 60    }
 61}
 62
 63impl fmt::Display for GithubLogin {
 64    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
 65        write!(formatter, "@{}", self.login)
 66    }
 67}
 68
/// A commit author, deserialized from a GraphQL `{ name email user { login } }` selection.
#[derive(Debug, Deserialize, Clone)]
pub struct CommitAuthor {
    /// Git author name.
    name: String,
    /// Git author email.
    email: String,
    /// Linked GitHub account; `None` when the API returns no linked user.
    user: Option<GithubLogin>,
}
 75
 76pub(crate) static ZED_ZIPPY_AUTHOR: LazyLock<CommitAuthor> = LazyLock::new(|| CommitAuthor {
 77    name: "Zed Zippy".to_string(),
 78    email: "234243425+zed-zippy[bot]@users.noreply.github.com".to_string(),
 79    user: Some(GithubLogin {
 80        login: "zed-zippy[bot]".to_string(),
 81    }),
 82});
 83
 84impl CommitAuthor {
 85    pub(crate) fn user(&self) -> Option<&GithubLogin> {
 86        self.user.as_ref()
 87    }
 88}
 89
 90impl PartialEq for CommitAuthor {
 91    fn eq(&self, other: &Self) -> bool {
 92        self.user.as_ref().zip(other.user.as_ref()).map_or_else(
 93            || self.email == other.email || self.name == other.name,
 94            |(l, r)| l == r,
 95        )
 96    }
 97}
 98
 99impl fmt::Display for CommitAuthor {
100    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
101        match self.user.as_ref() {
102            Some(user) => write!(formatter, "{} ({user})", self.name),
103            None => write!(formatter, "{} ({})", self.name, self.email),
104        }
105    }
106}
107
/// Authors of a single commit, deserialized from the commit's GraphQL `author`
/// field and `authors { nodes }` connection.
#[derive(Debug, Deserialize)]
pub struct CommitAuthors {
    /// The commit's git author (GraphQL `author`).
    #[serde(rename = "author")]
    primary_author: CommitAuthor,
    /// Nodes of the GraphQL `authors` connection (the query requests at most 10).
    /// NOTE(review): this list presumably also contains the primary author, which
    /// is why `co_authors()` filters out entries equal to `primary_author`.
    #[serde(rename = "authors", deserialize_with = "graph_ql::deserialize_nodes")]
    co_authors: Vec<CommitAuthor>,
}
115
116impl CommitAuthors {
117    pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
118        self.co_authors.is_empty().not().then(|| {
119            self.co_authors
120                .iter()
121                .filter(|co_author| *co_author != &self.primary_author)
122        })
123    }
124}
125
/// Commit authors keyed by commit SHA, as returned by
/// [`GithubApiClient::get_commit_authors`]. Dereferences to the inner map.
#[derive(Debug, Deref)]
pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
128
impl AuthorsForCommits {
    /// Prefix prepended to each SHA to form the GraphQL alias of its commit object.
    /// The deserializer strips it again.
    /// NOTE(review): presumably needed because a GraphQL alias may not start with
    /// a digit, which a hex SHA can.
    const SHA_PREFIX: &'static str = "commit";
}
132
133impl<'de> serde::Deserialize<'de> for AuthorsForCommits {
134    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
135    where
136        D: serde::Deserializer<'de>,
137    {
138        let raw = HashMap::<String, CommitAuthors>::deserialize(deserializer)?;
139        let map = raw
140            .into_iter()
141            .map(|(key, value)| {
142                let sha = key
143                    .strip_prefix(AuthorsForCommits::SHA_PREFIX)
144                    .unwrap_or(&key);
145                (CommitSha::new(sha.to_owned()), value)
146            })
147            .collect();
148        Ok(Self(map))
149    }
150}
151
/// A GitHub repository identified by its owner and name.
#[derive(Clone)]
pub struct Repository<'a> {
    owner: Cow<'a, str>,
    name: Cow<'a, str>,
}

impl<'a> Repository<'a> {
    /// The `zed-industries/zed` repository.
    pub const ZED: Repository<'static> = Repository::new_static("zed-industries", "zed");

    /// Creates a repository handle that borrows `owner` and `name`.
    pub fn new(owner: &'a str, name: &'a str) -> Self {
        let (owner, name) = (Cow::Borrowed(owner), Cow::Borrowed(name));
        Self { owner, name }
    }

    /// The user or organization that owns the repository.
    pub fn owner(&self) -> &str {
        self.owner.as_ref()
    }

    /// The repository name, without its owner.
    pub fn name(&self) -> &str {
        self.name.as_ref()
    }
}

impl Repository<'static> {
    /// `const` constructor from string literals, usable in constant definitions.
    pub const fn new_static(owner: &'static str, name: &'static str) -> Self {
        Repository {
            owner: Cow::Borrowed(owner),
            name: Cow::Borrowed(name),
        }
    }
}
185
/// The GitHub operations this crate depends on, abstracted behind a trait so that
/// implementations other than the octocrab-backed client can be substituted.
///
/// Futures returned by implementations are not required to be `Send`.
#[async_trait::async_trait(?Send)]
pub trait GithubApiClient {
    /// Fetches metadata for pull request `pr_number` in `repo`.
    async fn get_pull_request(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<PullRequestData>;
    /// Fetches the reviews submitted on pull request `pr_number` in `repo`.
    async fn get_pull_request_reviews(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<Vec<PullRequestReview>>;
    /// Fetches the comments on pull request `pr_number` in `repo`.
    async fn get_pull_request_comments(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<Vec<PullRequestComment>>;
    /// Fetches the primary author and co-authors of each commit in `commit_shas`,
    /// keyed by commit SHA.
    async fn get_commit_authors(
        &self,
        repo: &Repository<'_>,
        commit_shas: &[&CommitSha],
    ) -> Result<AuthorsForCommits>;
    /// Reports whether `login` has write access to `repo`.
    async fn check_repo_write_permission(
        &self,
        repo: &Repository<'_>,
        login: &GithubLogin,
    ) -> Result<bool>;
    /// Adds `label` to issue number `issue_number` in `repo`.
    async fn add_label_to_issue(
        &self,
        repo: &Repository<'_>,
        label: &str,
        issue_number: u64,
    ) -> Result<()>;
}
220
221pub mod graph_ql {
222    use anyhow::{Context as _, Result};
223    use itertools::Itertools as _;
224    use serde::Deserialize;
225
226    use crate::git::CommitSha;
227
228    use super::AuthorsForCommits;
229
    /// Top-level GraphQL response envelope carrying `data` and/or `errors`.
    #[derive(Debug, Deserialize)]
    pub struct GraphQLResponse<T> {
        /// The query result; absent or `null` in some failure cases.
        pub data: Option<T>,
        /// Errors reported by the server, if any.
        pub errors: Option<Vec<GraphQLError>>,
    }
235
236    impl<T> GraphQLResponse<T> {
237        pub fn into_data(self) -> Result<T> {
238            if let Some(errors) = &self.errors {
239                if !errors.is_empty() {
240                    let messages: String = errors.iter().map(|e| e.message.as_str()).join("; ");
241                    anyhow::bail!("GraphQL error: {messages}");
242                }
243            }
244
245            self.data.context("GraphQL response contained no data")
246        }
247    }
248
    /// One entry of a GraphQL response's `errors` list. Only the message is kept.
    #[derive(Debug, Deserialize)]
    pub struct GraphQLError {
        /// Human-readable error message.
        pub message: String,
    }
253
    /// Shape of the `data` payload for the query built by `build_co_authors_query`:
    /// a `repository` object whose fields are per-commit aliases.
    #[derive(Debug, Deserialize)]
    pub struct CommitAuthorsResponse {
        /// Per-commit author data, keyed by SHA after alias prefixes are stripped.
        pub repository: AuthorsForCommits,
    }
258
259    pub fn deserialize_nodes<'de, T, D>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
260    where
261        T: Deserialize<'de>,
262        D: serde::Deserializer<'de>,
263    {
264        #[derive(Deserialize)]
265        struct Nodes<T> {
266            nodes: Vec<T>,
267        }
268        Nodes::<T>::deserialize(deserializer).map(|wrapper| wrapper.nodes)
269    }
270
271    pub fn build_co_authors_query<'a>(
272        org: &str,
273        repo: &str,
274        shas: impl IntoIterator<Item = &'a CommitSha>,
275    ) -> String {
276        const FRAGMENT: &str = r#"
277            ... on Commit {
278                author {
279                    name
280                    email
281                    user { login }
282                }
283                authors(first: 10) {
284                    nodes {
285                        name
286                        email
287                        user { login }
288                    }
289                }
290            }
291        "#;
292
293        let objects = shas
294            .into_iter()
295            .map(|commit_sha| {
296                format!(
297                    "{sha_prefix}{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
298                    sha_prefix = AuthorsForCommits::SHA_PREFIX,
299                    sha = **commit_sha,
300                )
301            })
302            .join("\n");
303
304        format!("{{  repository(owner: \"{org}\", name: \"{repo}\") {{ {objects}  }} }}")
305            .replace("\n", "")
306    }
307}
308
309#[cfg(feature = "octo-client")]
310mod octo_client {
311    use anyhow::{Context, Result};
312    use futures::TryStreamExt as _;
313    use jsonwebtoken::EncodingKey;
314    use octocrab::{
315        Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
316        service::middleware::cache::mem::InMemoryCache,
317    };
318    use serde::de::DeserializeOwned;
319    use tokio::pin;
320
321    use crate::{
322        git::CommitSha,
323        github::{Repository, graph_ql},
324    };
325
326    use super::{
327        AuthorsForCommits, GithubApiClient, GithubLogin, GithubUser, PullRequestComment,
328        PullRequestData, PullRequestReview, ReviewState,
329    };
330
    /// Items per page requested from paginated REST endpoints.
    const PAGE_SIZE: u8 = 100;
332
    /// [`GithubApiClient`] implementation backed by an octocrab client that is
    /// authenticated as a GitHub App installation.
    pub struct OctocrabClient {
        /// Installation-scoped octocrab client.
        client: Octocrab,
    }
336
337    impl OctocrabClient {
338        pub async fn new(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
339            let octocrab = Octocrab::builder()
340                .cache(InMemoryCache::new())
341                .app(
342                    app_id.into(),
343                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
344                )
345                .build()?;
346
347            let installations = octocrab
348                .apps()
349                .installations()
350                .send()
351                .await
352                .context("Failed to fetch installations")?
353                .take_items();
354
355            let installation_id = installations
356                .into_iter()
357                .find(|installation| installation.account.login == org)
358                .context("Could not find Zed repository in installations")?
359                .id;
360
361            let client = octocrab.installation(installation_id)?;
362            Ok(Self { client })
363        }
364
365        async fn graphql<R: DeserializeOwned>(&self, query: &serde_json::Value) -> Result<R> {
366            let response: serde_json::Value = self.client.graphql(query).await?;
367            let parsed: graph_ql::GraphQLResponse<R> = serde_json::from_value(response)
368                .context("Failed to parse GraphQL response envelope")?;
369            parsed.into_data()
370        }
371
372        async fn get_all<T: DeserializeOwned + 'static>(
373            &self,
374            page: Page<T>,
375        ) -> octocrab::Result<Vec<T>> {
376            self.get_filtered(page, |_| true).await
377        }
378
379        async fn get_filtered<T: DeserializeOwned + 'static>(
380            &self,
381            page: Page<T>,
382            predicate: impl Fn(&T) -> bool,
383        ) -> octocrab::Result<Vec<T>> {
384            let stream = page.into_stream(&self.client);
385            pin!(stream);
386
387            let mut results = Vec::new();
388
389            while let Some(item) = stream.try_next().await?
390                && predicate(&item)
391            {
392                results.push(item);
393            }
394
395            Ok(results)
396        }
397    }
398
399    #[async_trait::async_trait(?Send)]
400    impl GithubApiClient for OctocrabClient {
401        async fn get_pull_request(
402            &self,
403            repo: &Repository<'_>,
404            pr_number: u64,
405        ) -> Result<PullRequestData> {
406            let pr = self
407                .client
408                .pulls(repo.owner.as_ref(), repo.name.as_ref())
409                .get(pr_number)
410                .await?;
411            Ok(PullRequestData {
412                number: pr.number,
413                user: pr.user.map(|user| GithubUser { login: user.login }),
414                merged_by: pr.merged_by.map(|user| GithubUser { login: user.login }),
415                labels: pr
416                    .labels
417                    .map(|labels| labels.into_iter().map(|label| label.name).collect()),
418            })
419        }
420
421        async fn get_pull_request_reviews(
422            &self,
423            repo: &Repository<'_>,
424            pr_number: u64,
425        ) -> Result<Vec<PullRequestReview>> {
426            let page = self
427                .client
428                .pulls(repo.owner.as_ref(), repo.name.as_ref())
429                .list_reviews(pr_number)
430                .per_page(PAGE_SIZE)
431                .send()
432                .await?;
433
434            let reviews = self.get_all(page).await?;
435
436            Ok(reviews
437                .into_iter()
438                .map(|review| PullRequestReview {
439                    user: review.user.map(|user| GithubUser { login: user.login }),
440                    state: review.state.map(|state| match state {
441                        OctocrabReviewState::Approved => ReviewState::Approved,
442                        _ => ReviewState::Other,
443                    }),
444                    body: review.body,
445                })
446                .collect())
447        }
448
449        async fn get_pull_request_comments(
450            &self,
451            repo: &Repository<'_>,
452            pr_number: u64,
453        ) -> Result<Vec<PullRequestComment>> {
454            let page = self
455                .client
456                .issues(repo.owner.as_ref(), repo.name.as_ref())
457                .list_comments(pr_number)
458                .per_page(PAGE_SIZE)
459                .send()
460                .await?;
461
462            let comments = self.get_all(page).await?;
463
464            Ok(comments
465                .into_iter()
466                .map(|comment| PullRequestComment {
467                    user: GithubUser {
468                        login: comment.user.login,
469                    },
470                    body: comment.body,
471                })
472                .collect())
473        }
474
475        async fn get_commit_authors(
476            &self,
477            repo: &Repository<'_>,
478            commit_shas: &[&CommitSha],
479        ) -> Result<AuthorsForCommits> {
480            let query = graph_ql::build_co_authors_query(
481                repo.owner.as_ref(),
482                repo.name.as_ref(),
483                commit_shas.iter().copied(),
484            );
485            let query = serde_json::json!({ "query": query });
486            self.graphql::<graph_ql::CommitAuthorsResponse>(&query)
487                .await
488                .map(|response| response.repository)
489        }
490
491        async fn check_repo_write_permission(
492            &self,
493            repo: &Repository<'_>,
494            login: &GithubLogin,
495        ) -> Result<bool> {
496            // Check org membership first - we save ourselves a few request that way
497            let page = self
498                .client
499                .orgs(repo.owner.as_ref())
500                .list_members()
501                .per_page(PAGE_SIZE)
502                .send()
503                .await?;
504
505            let members = self.get_all(page).await?;
506
507            if members
508                .into_iter()
509                .any(|member| member.login == login.as_str())
510            {
511                return Ok(true);
512            }
513
514            // TODO: octocrab fails to deserialize the permission response and
515            // does not adhere to the scheme laid out at
516            // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user
517
518            #[derive(serde::Deserialize)]
519            #[serde(rename_all = "lowercase")]
520            enum RepoPermission {
521                Admin,
522                Write,
523                Read,
524                #[serde(other)]
525                Other,
526            }
527
528            #[derive(serde::Deserialize)]
529            struct RepositoryPermissions {
530                permission: RepoPermission,
531            }
532
533            self.client
534                .get::<RepositoryPermissions, _, _>(
535                    format!(
536                        "/repos/{owner}/{repo}/collaborators/{user}/permission",
537                        owner = repo.owner.as_ref(),
538                        repo = repo.name.as_ref(),
539                        user = login.as_str()
540                    ),
541                    None::<&()>,
542                )
543                .await
544                .map(|response| {
545                    matches!(
546                        response.permission,
547                        RepoPermission::Write | RepoPermission::Admin
548                    )
549                })
550                .map_err(Into::into)
551        }
552
553        async fn add_label_to_issue(
554            &self,
555            repo: &Repository<'_>,
556            label: &str,
557            issue_number: u64,
558        ) -> Result<()> {
559            self.client
560                .issues(repo.owner.as_ref(), repo.name.as_ref())
561                .add_labels(issue_number, &[label.to_owned()])
562                .await
563                .map(|_| ())
564                .map_err(Into::into)
565        }
566    }
567}
568
569#[cfg(feature = "octo-client")]
570pub use octo_client::OctocrabClient;