github.rs

  1use std::{collections::HashMap, fmt, ops::Not, rc::Rc};
  2
  3use anyhow::Result;
  4use derive_more::Deref;
  5use serde::Deserialize;
  6
  7use crate::git::CommitSha;
  8
  9pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
 10
 11#[derive(Debug, Clone)]
 12pub struct GitHubUser {
 13    pub login: String,
 14}
 15
 16#[derive(Debug, Clone)]
 17pub struct PullRequestData {
 18    pub number: u64,
 19    pub user: Option<GitHubUser>,
 20    pub merged_by: Option<GitHubUser>,
 21}
 22
 23#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 24pub enum ReviewState {
 25    Approved,
 26    Other,
 27}
 28
 29#[derive(Debug, Clone)]
 30pub struct PullRequestReview {
 31    pub user: Option<GitHubUser>,
 32    pub state: Option<ReviewState>,
 33}
 34
 35#[derive(Debug, Clone)]
 36pub struct PullRequestComment {
 37    pub user: GitHubUser,
 38    pub body: Option<String>,
 39}
 40
 41#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
 42pub struct GithubLogin {
 43    login: String,
 44}
 45
 46impl GithubLogin {
 47    pub fn new(login: String) -> Self {
 48        Self { login }
 49    }
 50}
 51
 52impl fmt::Display for GithubLogin {
 53    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
 54        write!(formatter, "@{}", self.login)
 55    }
 56}
 57
 58#[derive(Debug, Deserialize, Clone)]
 59pub struct CommitAuthor {
 60    name: String,
 61    email: String,
 62    user: Option<GithubLogin>,
 63}
 64
 65impl CommitAuthor {
 66    pub(crate) fn user(&self) -> Option<&GithubLogin> {
 67        self.user.as_ref()
 68    }
 69}
 70
 71impl PartialEq for CommitAuthor {
 72    fn eq(&self, other: &Self) -> bool {
 73        self.user.as_ref().zip(other.user.as_ref()).map_or_else(
 74            || self.email == other.email || self.name == other.name,
 75            |(l, r)| l == r,
 76        )
 77    }
 78}
 79
 80impl fmt::Display for CommitAuthor {
 81    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
 82        match self.user.as_ref() {
 83            Some(user) => write!(formatter, "{} ({user})", self.name),
 84            None => write!(formatter, "{} ({})", self.name, self.email),
 85        }
 86    }
 87}
 88
 89#[derive(Debug, Deserialize)]
 90pub struct CommitAuthors {
 91    #[serde(rename = "author")]
 92    primary_author: CommitAuthor,
 93    #[serde(rename = "authors")]
 94    co_authors: Vec<CommitAuthor>,
 95}
 96
 97impl CommitAuthors {
 98    pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
 99        self.co_authors.is_empty().not().then(|| {
100            self.co_authors
101                .iter()
102                .filter(|co_author| *co_author != &self.primary_author)
103        })
104    }
105}
106
107#[derive(Debug, Deserialize, Deref)]
108pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
109
110#[async_trait::async_trait(?Send)]
111pub trait GitHubApiClient {
112    async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData>;
113    async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>>;
114    async fn get_pull_request_comments(&self, pr_number: u64) -> Result<Vec<PullRequestComment>>;
115    async fn get_commit_authors(&self, commit_shas: &[&CommitSha]) -> Result<AuthorsForCommits>;
116    async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool>;
117    async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool>;
118    async fn actor_has_repository_write_permission(
119        &self,
120        login: &GithubLogin,
121    ) -> anyhow::Result<bool> {
122        Ok(self.check_org_membership(login).await?
123            || self.check_repo_write_permission(login).await?)
124    }
125    async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()>;
126}
127
128#[derive(Deref)]
129pub struct GitHubClient {
130    api: Rc<dyn GitHubApiClient>,
131}
132
133impl GitHubClient {
134    pub fn new(api: Rc<dyn GitHubApiClient>) -> Self {
135        Self { api }
136    }
137
138    #[cfg(feature = "octo-client")]
139    pub async fn for_app(app_id: u64, app_private_key: &str) -> Result<Self> {
140        let client = OctocrabClient::new(app_id, app_private_key).await?;
141        Ok(Self::new(Rc::new(client)))
142    }
143}
144
145#[cfg(feature = "octo-client")]
146mod octo_client {
147    use anyhow::{Context, Result};
148    use futures::TryStreamExt as _;
149    use itertools::Itertools;
150    use jsonwebtoken::EncodingKey;
151    use octocrab::{
152        Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
153        service::middleware::cache::mem::InMemoryCache,
154    };
155    use serde::de::DeserializeOwned;
156    use tokio::pin;
157
158    use crate::git::CommitSha;
159
160    use super::{
161        AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment,
162        PullRequestData, PullRequestReview, ReviewState,
163    };
164
165    const PAGE_SIZE: u8 = 100;
166    const ORG: &str = "zed-industries";
167    const REPO: &str = "zed";
168
169    pub struct OctocrabClient {
170        client: Octocrab,
171    }
172
173    impl OctocrabClient {
174        pub async fn new(app_id: u64, app_private_key: &str) -> Result<Self> {
175            let octocrab = Octocrab::builder()
176                .cache(InMemoryCache::new())
177                .app(
178                    app_id.into(),
179                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
180                )
181                .build()?;
182
183            let installations = octocrab
184                .apps()
185                .installations()
186                .send()
187                .await
188                .context("Failed to fetch installations")?
189                .take_items();
190
191            let installation_id = installations
192                .into_iter()
193                .find(|installation| installation.account.login == ORG)
194                .context("Could not find Zed repository in installations")?
195                .id;
196
197            let client = octocrab.installation(installation_id)?;
198            Ok(Self { client })
199        }
200
201        fn build_co_authors_query<'a>(shas: impl IntoIterator<Item = &'a CommitSha>) -> String {
202            const FRAGMENT: &str = r#"
203                ... on Commit {
204                    author {
205                        name
206                        email
207                        user { login }
208                    }
209                    authors(first: 10) {
210                        nodes {
211                            name
212                            email
213                            user { login }
214                        }
215                    }
216                }
217            "#;
218
219            let objects: String = shas
220                .into_iter()
221                .map(|commit_sha| {
222                    format!(
223                        "commit{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
224                        sha = **commit_sha
225                    )
226                })
227                .join("\n");
228
229            format!("{{  repository(owner: \"{ORG}\", name: \"{REPO}\") {{ {objects}  }} }}")
230                .replace("\n", "")
231        }
232
233        async fn graphql<R: octocrab::FromResponse>(
234            &self,
235            query: &serde_json::Value,
236        ) -> octocrab::Result<R> {
237            self.client.graphql(query).await
238        }
239
240        async fn get_all<T: DeserializeOwned + 'static>(
241            &self,
242            page: Page<T>,
243        ) -> octocrab::Result<Vec<T>> {
244            self.get_filtered(page, |_| true).await
245        }
246
247        async fn get_filtered<T: DeserializeOwned + 'static>(
248            &self,
249            page: Page<T>,
250            predicate: impl Fn(&T) -> bool,
251        ) -> octocrab::Result<Vec<T>> {
252            let stream = page.into_stream(&self.client);
253            pin!(stream);
254
255            let mut results = Vec::new();
256
257            while let Some(item) = stream.try_next().await?
258                && predicate(&item)
259            {
260                results.push(item);
261            }
262
263            Ok(results)
264        }
265    }
266
267    #[async_trait::async_trait(?Send)]
268    impl GitHubApiClient for OctocrabClient {
269        async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
270            let pr = self.client.pulls(ORG, REPO).get(pr_number).await?;
271            Ok(PullRequestData {
272                number: pr.number,
273                user: pr.user.map(|user| GitHubUser { login: user.login }),
274                merged_by: pr.merged_by.map(|user| GitHubUser { login: user.login }),
275            })
276        }
277
278        async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
279            let page = self
280                .client
281                .pulls(ORG, REPO)
282                .list_reviews(pr_number)
283                .per_page(PAGE_SIZE)
284                .send()
285                .await?;
286
287            let reviews = self.get_all(page).await?;
288
289            Ok(reviews
290                .into_iter()
291                .map(|review| PullRequestReview {
292                    user: review.user.map(|user| GitHubUser { login: user.login }),
293                    state: review.state.map(|state| match state {
294                        OctocrabReviewState::Approved => ReviewState::Approved,
295                        _ => ReviewState::Other,
296                    }),
297                })
298                .collect())
299        }
300
301        async fn get_pull_request_comments(
302            &self,
303            pr_number: u64,
304        ) -> Result<Vec<PullRequestComment>> {
305            let page = self
306                .client
307                .issues(ORG, REPO)
308                .list_comments(pr_number)
309                .per_page(PAGE_SIZE)
310                .send()
311                .await?;
312
313            let comments = self.get_all(page).await?;
314
315            Ok(comments
316                .into_iter()
317                .map(|comment| PullRequestComment {
318                    user: GitHubUser {
319                        login: comment.user.login,
320                    },
321                    body: comment.body,
322                })
323                .collect())
324        }
325
326        async fn get_commit_authors(
327            &self,
328            commit_shas: &[&CommitSha],
329        ) -> Result<AuthorsForCommits> {
330            let query = Self::build_co_authors_query(commit_shas.iter().copied());
331            let query = serde_json::json!({ "query": query });
332            let mut response = self.graphql::<serde_json::Value>(&query).await?;
333
334            response
335                .get_mut("data")
336                .and_then(|data| data.get_mut("repository"))
337                .and_then(|repo| repo.as_object_mut())
338                .ok_or_else(|| anyhow::anyhow!("Unexpected response format!"))
339                .and_then(|commit_data| {
340                    let mut response_map = serde_json::Map::with_capacity(commit_data.len());
341
342                    for (key, value) in commit_data.iter_mut() {
343                        let key_without_prefix = key.strip_prefix("commit").unwrap_or(key);
344                        if let Some(authors) = value.get_mut("authors") {
345                            if let Some(nodes) = authors.get("nodes") {
346                                *authors = nodes.clone();
347                            }
348                        }
349
350                        response_map.insert(key_without_prefix.to_owned(), value.clone());
351                    }
352
353                    serde_json::from_value(serde_json::Value::Object(response_map))
354                        .context("Failed to deserialize commit authors")
355                })
356        }
357
358        async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
359            let page = self
360                .client
361                .orgs(ORG)
362                .list_members()
363                .per_page(PAGE_SIZE)
364                .send()
365                .await?;
366
367            let members = self.get_all(page).await?;
368
369            Ok(members
370                .into_iter()
371                .any(|member| member.login == login.as_str()))
372        }
373
374        async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool> {
375            // TODO: octocrab fails to deserialize the permission response and
376            // does not adhere to the scheme laid out at
377            // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user
378
379            #[derive(serde::Deserialize)]
380            #[serde(rename_all = "lowercase")]
381            enum RepoPermission {
382                Admin,
383                Write,
384                Read,
385                #[serde(other)]
386                Other,
387            }
388
389            #[derive(serde::Deserialize)]
390            struct RepositoryPermissions {
391                permission: RepoPermission,
392            }
393
394            self.client
395                .get::<RepositoryPermissions, _, _>(
396                    format!(
397                        "/repos/{ORG}/{REPO}/collaborators/{user}/permission",
398                        user = login.as_str()
399                    ),
400                    None::<&()>,
401                )
402                .await
403                .map(|response| {
404                    matches!(
405                        response.permission,
406                        RepoPermission::Write | RepoPermission::Admin
407                    )
408                })
409                .map_err(Into::into)
410        }
411
412        async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()> {
413            if self
414                .get_filtered(
415                    self.client
416                        .issues(ORG, REPO)
417                        .list_labels_for_issue(pr_number)
418                        .per_page(PAGE_SIZE)
419                        .send()
420                        .await?,
421                    |pr_label| pr_label.name == label,
422                )
423                .await
424                .is_ok_and(|l| l.is_empty())
425            {
426                self.client
427                    .issues(ORG, REPO)
428                    .add_labels(pr_number, &[label.to_owned()])
429                    .await?;
430            }
431
432            Ok(())
433        }
434    }
435}
436
437#[cfg(feature = "octo-client")]
438pub use octo_client::OctocrabClient;