github.rs

  1use std::{borrow::Cow, collections::HashMap, fmt, ops::Not, rc::Rc};
  2
  3use anyhow::Result;
  4use derive_more::Deref;
  5use serde::Deserialize;
  6
  7use crate::git::CommitSha;
  8
/// Label text `PR state:needs review`.
///
/// NOTE(review): presumably the label applied to pull requests that still await
/// a review; the usage site is not visible in this file — confirm against callers.
pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
 10
/// Minimal view of a GitHub user: only the login is retained.
#[derive(Debug, Clone)]
pub struct GithubUser {
    /// Login of the user.
    pub login: String,
}
 15
/// Subset of a pull request's data.
#[derive(Debug, Clone)]
pub struct PullRequestData {
    /// Number of the pull request.
    pub number: u64,
    /// User attached to the pull request, if the API reported one
    /// (presumably the author — TODO confirm).
    pub user: Option<GithubUser>,
    /// User recorded as having merged the pull request, if any.
    pub merged_by: Option<GithubUser>,
    /// Names of the labels on the pull request, if the API returned labels.
    pub labels: Option<Vec<String>>,
}
 23
/// Simplified state of a pull request review.
///
/// Only approvals are distinguished; every other upstream state is folded
/// into [`ReviewState::Other`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReviewState {
    /// The review approved the pull request.
    Approved,
    /// Any state other than an approval.
    Other,
}
 29
/// A single review left on a pull request.
#[derive(Debug, Clone)]
pub struct PullRequestReview {
    /// User that submitted the review, if known.
    pub user: Option<GithubUser>,
    /// Simplified state of the review, if the API reported one.
    pub state: Option<ReviewState>,
    /// Text body of the review, if any.
    pub body: Option<String>,
}
 36
 37impl PullRequestReview {
 38    pub fn with_body(self, body: impl ToString) -> Self {
 39        Self {
 40            body: Some(body.to_string()),
 41            ..self
 42        }
 43    }
 44}
 45
/// A single comment left on a pull request.
#[derive(Debug, Clone)]
pub struct PullRequestComment {
    /// User that wrote the comment.
    pub user: GithubUser,
    /// Text body of the comment, if any.
    pub body: Option<String>,
}
 51
/// A GitHub login name.
///
/// Deserializes from an object with a `login` field. `Deref` is derived via
/// `derive_more`; other code in this file calls `as_str()` directly on a
/// `GithubLogin`, so it presumably dereferences to the inner `String`.
#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
pub struct GithubLogin {
    // The login without any leading `@`; the `Display` impl adds that prefix.
    login: String,
}
 56
 57impl GithubLogin {
 58    pub fn new(login: String) -> Self {
 59        Self { login }
 60    }
 61}
 62
 63impl fmt::Display for GithubLogin {
 64    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
 65        write!(formatter, "@{}", self.login)
 66    }
 67}
 68
/// An author attached to a commit.
///
/// Equality is implemented by hand (see the `PartialEq` impl) rather than
/// derived.
#[derive(Debug, Deserialize, Clone)]
pub struct CommitAuthor {
    // Author name as recorded on the commit.
    name: String,
    // Author e-mail as recorded on the commit.
    email: String,
    // GitHub login linked to this author, when one is present in the response.
    user: Option<GithubLogin>,
}
 75
 76impl CommitAuthor {
 77    pub(crate) fn user(&self) -> Option<&GithubLogin> {
 78        self.user.as_ref()
 79    }
 80}
 81
 82impl PartialEq for CommitAuthor {
 83    fn eq(&self, other: &Self) -> bool {
 84        self.user.as_ref().zip(other.user.as_ref()).map_or_else(
 85            || self.email == other.email || self.name == other.name,
 86            |(l, r)| l == r,
 87        )
 88    }
 89}
 90
 91impl fmt::Display for CommitAuthor {
 92    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
 93        match self.user.as_ref() {
 94            Some(user) => write!(formatter, "{} ({user})", self.name),
 95            None => write!(formatter, "{} ({})", self.name, self.email),
 96        }
 97    }
 98}
 99
/// The authors of a single commit as returned by the GraphQL query built in
/// `graph_ql::build_co_authors_query`.
#[derive(Debug, Deserialize)]
pub struct CommitAuthors {
    // Read from the `author` field of the response.
    #[serde(rename = "author")]
    primary_author: CommitAuthor,
    // Read from the `nodes` list nested under the `authors` field.
    #[serde(rename = "authors", deserialize_with = "graph_ql::deserialize_nodes")]
    co_authors: Vec<CommitAuthor>,
}
107
108impl CommitAuthors {
109    pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
110        self.co_authors.is_empty().not().then(|| {
111            self.co_authors
112                .iter()
113                .filter(|co_author| *co_author != &self.primary_author)
114        })
115    }
116}
117
/// Commit authors keyed by commit SHA.
///
/// `Deref` is derived (via `derive_more`) for this single-field wrapper around
/// the map.
#[derive(Debug, Deref)]
pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
120
impl AuthorsForCommits {
    // Prefix put in front of each commit SHA to form its alias in the GraphQL
    // query (`graph_ql::build_co_authors_query`); the `Deserialize` impl strips
    // it again from the response keys.
    const SHA_PREFIX: &'static str = "commit";
}
124
125impl<'de> serde::Deserialize<'de> for AuthorsForCommits {
126    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
127    where
128        D: serde::Deserializer<'de>,
129    {
130        let raw = HashMap::<String, CommitAuthors>::deserialize(deserializer)?;
131        let map = raw
132            .into_iter()
133            .map(|(key, value)| {
134                let sha = key
135                    .strip_prefix(AuthorsForCommits::SHA_PREFIX)
136                    .unwrap_or(&key);
137                (CommitSha::new(sha.to_owned()), value)
138            })
139            .collect();
140        Ok(Self(map))
141    }
142}
143
/// A GitHub repository identified by owner and name.
///
/// Both parts are stored as `Cow<str>`; the constructors in this file only
/// ever create borrowed values.
#[derive(Clone)]
pub struct Repository<'a> {
    // Owner (user or organization) part of `owner/name`.
    owner: Cow<'a, str>,
    // Repository part of `owner/name`.
    name: Cow<'a, str>,
}
149
150impl<'a> Repository<'a> {
151    pub const ZED: Repository<'static> = Repository::new_static("zed-industries", "zed");
152
153    pub fn new(owner: &'a str, name: &'a str) -> Self {
154        Self {
155            owner: Cow::Borrowed(owner),
156            name: Cow::Borrowed(name),
157        }
158    }
159
160    pub fn owner(&self) -> &str {
161        &self.owner
162    }
163
164    pub fn name(&self) -> &str {
165        &self.name
166    }
167}
168
169impl Repository<'static> {
170    pub const fn new_static(owner: &'static str, name: &'static str) -> Self {
171        Self {
172            owner: Cow::Borrowed(owner),
173            name: Cow::Borrowed(name),
174        }
175    }
176}
177
/// The GitHub operations used through [`GithubClient`].
///
/// The trait is declared with `async_trait(?Send)`, so the futures returned by
/// its methods are not required to be `Send`.
#[async_trait::async_trait(?Send)]
pub trait GithubApiClient {
    /// Fetches the data of pull request `pr_number` in `repo`.
    async fn get_pull_request(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<PullRequestData>;
    /// Fetches the reviews of pull request `pr_number` in `repo`.
    async fn get_pull_request_reviews(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<Vec<PullRequestReview>>;
    /// Fetches the comments of pull request `pr_number` in `repo`.
    async fn get_pull_request_comments(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<Vec<PullRequestComment>>;
    /// Fetches the authors of each commit in `commit_shas`, keyed by SHA.
    async fn get_commit_authors(
        &self,
        repo: &Repository<'_>,
        commit_shas: &[&CommitSha],
    ) -> Result<AuthorsForCommits>;
    /// Reports whether `login` may write to `repo`.
    async fn check_repo_write_permission(
        &self,
        repo: &Repository<'_>,
        login: &GithubLogin,
    ) -> Result<bool>;
    /// Adds `label` to the issue (or pull request) `issue_number` in `repo`.
    async fn add_label_to_issue(
        &self,
        repo: &Repository<'_>,
        label: &str,
        issue_number: u64,
    ) -> Result<()>;
}
212
/// Handle around a reference-counted [`GithubApiClient`] implementation.
///
/// `Deref` is derived (via `derive_more`) for the single `api` field.
#[derive(Deref)]
pub struct GithubClient {
    // The wrapped API implementation.
    api: Rc<dyn GithubApiClient>,
}
217
218impl GithubClient {
219    pub fn new(api: Rc<dyn GithubApiClient>) -> Self {
220        Self { api }
221    }
222
223    #[cfg(feature = "octo-client")]
224    pub async fn for_app_in_repo(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
225        let client = OctocrabClient::new(app_id, app_private_key, org).await?;
226        Ok(Self::new(Rc::new(client)))
227    }
228}
229
230pub mod graph_ql {
231    use anyhow::{Context as _, Result};
232    use itertools::Itertools as _;
233    use serde::Deserialize;
234
235    use crate::git::CommitSha;
236
237    use super::AuthorsForCommits;
238
    /// Envelope of a GraphQL response: an optional `data` payload and an
    /// optional list of `errors`.
    #[derive(Debug, Deserialize)]
    pub struct GraphQLResponse<T> {
        /// The payload, if the response contained one.
        pub data: Option<T>,
        /// The errors reported in the response, if any.
        pub errors: Option<Vec<GraphQLError>>,
    }
244
245    impl<T> GraphQLResponse<T> {
246        pub fn into_data(self) -> Result<T> {
247            if let Some(errors) = &self.errors {
248                if !errors.is_empty() {
249                    let messages: String = errors.iter().map(|e| e.message.as_str()).join("; ");
250                    anyhow::bail!("GraphQL error: {messages}");
251                }
252            }
253
254            self.data.context("GraphQL response contained no data")
255        }
256    }
257
    /// A single entry of a GraphQL response's `errors` list; only the message
    /// is read.
    #[derive(Debug, Deserialize)]
    pub struct GraphQLError {
        /// The error message.
        pub message: String,
    }
262
    /// Payload of the query produced by [`build_co_authors_query`].
    #[derive(Debug, Deserialize)]
    pub struct CommitAuthorsResponse {
        /// Contents of the response's `repository` field: the authors of each
        /// requested commit.
        pub repository: AuthorsForCommits,
    }
267
268    pub fn deserialize_nodes<'de, T, D>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
269    where
270        T: Deserialize<'de>,
271        D: serde::Deserializer<'de>,
272    {
273        #[derive(Deserialize)]
274        struct Nodes<T> {
275            nodes: Vec<T>,
276        }
277        Nodes::<T>::deserialize(deserializer).map(|wrapper| wrapper.nodes)
278    }
279
280    pub fn build_co_authors_query<'a>(
281        org: &str,
282        repo: &str,
283        shas: impl IntoIterator<Item = &'a CommitSha>,
284    ) -> String {
285        const FRAGMENT: &str = r#"
286            ... on Commit {
287                author {
288                    name
289                    email
290                    user { login }
291                }
292                authors(first: 10) {
293                    nodes {
294                        name
295                        email
296                        user { login }
297                    }
298                }
299            }
300        "#;
301
302        let objects = shas
303            .into_iter()
304            .map(|commit_sha| {
305                format!(
306                    "{sha_prefix}{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
307                    sha_prefix = AuthorsForCommits::SHA_PREFIX,
308                    sha = **commit_sha,
309                )
310            })
311            .join("\n");
312
313        format!("{{  repository(owner: \"{org}\", name: \"{repo}\") {{ {objects}  }} }}")
314            .replace("\n", "")
315    }
316}
317
318#[cfg(feature = "octo-client")]
319mod octo_client {
320    use anyhow::{Context, Result};
321    use futures::TryStreamExt as _;
322    use jsonwebtoken::EncodingKey;
323    use octocrab::{
324        Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
325        service::middleware::cache::mem::InMemoryCache,
326    };
327    use serde::de::DeserializeOwned;
328    use tokio::pin;
329
330    use crate::{
331        git::CommitSha,
332        github::{Repository, graph_ql},
333    };
334
335    use super::{
336        AuthorsForCommits, GithubApiClient, GithubLogin, GithubUser, PullRequestComment,
337        PullRequestData, PullRequestReview, ReviewState,
338    };
339
    // Value passed to `per_page` on the paginated REST requests in this module.
    const PAGE_SIZE: u8 = 100;
341
    /// [`GithubApiClient`] implementation backed by an `octocrab` client.
    pub struct OctocrabClient {
        // The installation-scoped client created in `OctocrabClient::new`.
        client: Octocrab,
    }
345
346    impl OctocrabClient {
347        pub async fn new(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
348            let octocrab = Octocrab::builder()
349                .cache(InMemoryCache::new())
350                .app(
351                    app_id.into(),
352                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
353                )
354                .build()?;
355
356            let installations = octocrab
357                .apps()
358                .installations()
359                .send()
360                .await
361                .context("Failed to fetch installations")?
362                .take_items();
363
364            let installation_id = installations
365                .into_iter()
366                .find(|installation| installation.account.login == org)
367                .context("Could not find Zed repository in installations")?
368                .id;
369
370            let client = octocrab.installation(installation_id)?;
371            Ok(Self { client })
372        }
373
374        async fn graphql<R: DeserializeOwned>(&self, query: &serde_json::Value) -> Result<R> {
375            let response: serde_json::Value = self.client.graphql(query).await?;
376            let parsed: graph_ql::GraphQLResponse<R> = serde_json::from_value(response)
377                .context("Failed to parse GraphQL response envelope")?;
378            parsed.into_data()
379        }
380
381        async fn get_all<T: DeserializeOwned + 'static>(
382            &self,
383            page: Page<T>,
384        ) -> octocrab::Result<Vec<T>> {
385            self.get_filtered(page, |_| true).await
386        }
387
388        async fn get_filtered<T: DeserializeOwned + 'static>(
389            &self,
390            page: Page<T>,
391            predicate: impl Fn(&T) -> bool,
392        ) -> octocrab::Result<Vec<T>> {
393            let stream = page.into_stream(&self.client);
394            pin!(stream);
395
396            let mut results = Vec::new();
397
398            while let Some(item) = stream.try_next().await?
399                && predicate(&item)
400            {
401                results.push(item);
402            }
403
404            Ok(results)
405        }
406    }
407
408    #[async_trait::async_trait(?Send)]
409    impl GithubApiClient for OctocrabClient {
410        async fn get_pull_request(
411            &self,
412            repo: &Repository<'_>,
413            pr_number: u64,
414        ) -> Result<PullRequestData> {
415            let pr = self
416                .client
417                .pulls(repo.owner.as_ref(), repo.name.as_ref())
418                .get(pr_number)
419                .await?;
420            Ok(PullRequestData {
421                number: pr.number,
422                user: pr.user.map(|user| GithubUser { login: user.login }),
423                merged_by: pr.merged_by.map(|user| GithubUser { login: user.login }),
424                labels: pr
425                    .labels
426                    .map(|labels| labels.into_iter().map(|label| label.name).collect()),
427            })
428        }
429
430        async fn get_pull_request_reviews(
431            &self,
432            repo: &Repository<'_>,
433            pr_number: u64,
434        ) -> Result<Vec<PullRequestReview>> {
435            let page = self
436                .client
437                .pulls(repo.owner.as_ref(), repo.name.as_ref())
438                .list_reviews(pr_number)
439                .per_page(PAGE_SIZE)
440                .send()
441                .await?;
442
443            let reviews = self.get_all(page).await?;
444
445            Ok(reviews
446                .into_iter()
447                .map(|review| PullRequestReview {
448                    user: review.user.map(|user| GithubUser { login: user.login }),
449                    state: review.state.map(|state| match state {
450                        OctocrabReviewState::Approved => ReviewState::Approved,
451                        _ => ReviewState::Other,
452                    }),
453                    body: review.body,
454                })
455                .collect())
456        }
457
458        async fn get_pull_request_comments(
459            &self,
460            repo: &Repository<'_>,
461            pr_number: u64,
462        ) -> Result<Vec<PullRequestComment>> {
463            let page = self
464                .client
465                .issues(repo.owner.as_ref(), repo.name.as_ref())
466                .list_comments(pr_number)
467                .per_page(PAGE_SIZE)
468                .send()
469                .await?;
470
471            let comments = self.get_all(page).await?;
472
473            Ok(comments
474                .into_iter()
475                .map(|comment| PullRequestComment {
476                    user: GithubUser {
477                        login: comment.user.login,
478                    },
479                    body: comment.body,
480                })
481                .collect())
482        }
483
484        async fn get_commit_authors(
485            &self,
486            repo: &Repository<'_>,
487            commit_shas: &[&CommitSha],
488        ) -> Result<AuthorsForCommits> {
489            let query = graph_ql::build_co_authors_query(
490                repo.owner.as_ref(),
491                repo.name.as_ref(),
492                commit_shas.iter().copied(),
493            );
494            let query = serde_json::json!({ "query": query });
495            self.graphql::<graph_ql::CommitAuthorsResponse>(&query)
496                .await
497                .map(|response| response.repository)
498        }
499
500        async fn check_repo_write_permission(
501            &self,
502            repo: &Repository<'_>,
503            login: &GithubLogin,
504        ) -> Result<bool> {
505            // Check org membership first - we save ourselves a few request that way
506            let page = self
507                .client
508                .orgs(repo.owner.as_ref())
509                .list_members()
510                .per_page(PAGE_SIZE)
511                .send()
512                .await?;
513
514            let members = self.get_all(page).await?;
515
516            if members
517                .into_iter()
518                .any(|member| member.login == login.as_str())
519            {
520                return Ok(true);
521            }
522
523            // TODO: octocrab fails to deserialize the permission response and
524            // does not adhere to the scheme laid out at
525            // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user
526
527            #[derive(serde::Deserialize)]
528            #[serde(rename_all = "lowercase")]
529            enum RepoPermission {
530                Admin,
531                Write,
532                Read,
533                #[serde(other)]
534                Other,
535            }
536
537            #[derive(serde::Deserialize)]
538            struct RepositoryPermissions {
539                permission: RepoPermission,
540            }
541
542            self.client
543                .get::<RepositoryPermissions, _, _>(
544                    format!(
545                        "/repos/{owner}/{repo}/collaborators/{user}/permission",
546                        owner = repo.owner.as_ref(),
547                        repo = repo.name.as_ref(),
548                        user = login.as_str()
549                    ),
550                    None::<&()>,
551                )
552                .await
553                .map(|response| {
554                    matches!(
555                        response.permission,
556                        RepoPermission::Write | RepoPermission::Admin
557                    )
558                })
559                .map_err(Into::into)
560        }
561
562        async fn add_label_to_issue(
563            &self,
564            repo: &Repository<'_>,
565            label: &str,
566            issue_number: u64,
567        ) -> Result<()> {
568            self.client
569                .issues(repo.owner.as_ref(), repo.name.as_ref())
570                .add_labels(issue_number, &[label.to_owned()])
571                .await
572                .map(|_| ())
573                .map_err(Into::into)
574        }
575    }
576}
577
578#[cfg(feature = "octo-client")]
579pub use octo_client::OctocrabClient;