github.rs

  1use std::{collections::HashMap, fmt, ops::Not, rc::Rc};
  2
  3use anyhow::Result;
  4use derive_more::Deref;
  5use serde::Deserialize;
  6
  7use crate::git::CommitSha;
  8
/// Name of the GitHub label that marks a pull request as needing review.
pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
 10
/// A GitHub account, identified only by its login name.
#[derive(Debug, Clone)]
pub struct GitHubUser {
    /// The account's login (username).
    pub login: String,
}
 15
/// The subset of pull request metadata used by this crate.
#[derive(Debug, Clone)]
pub struct PullRequestData {
    /// The pull request number within the repository.
    pub number: u64,
    /// The pull request author, if GitHub reported one.
    pub user: Option<GitHubUser>,
    /// The user who merged the pull request, if GitHub reported one
    /// (presumably `None` for unmerged pull requests — confirm against the API).
    pub merged_by: Option<GitHubUser>,
}
 22
/// Review state reduced to the single distinction callers need.
///
/// Approving reviews map to `Approved`; every other review state reported by
/// GitHub maps to `Other`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReviewState {
    Approved,
    Other,
}
 28
 29#[derive(Debug, Clone)]
 30pub struct PullRequestReview {
 31    pub user: Option<GitHubUser>,
 32    pub state: Option<ReviewState>,
 33    pub body: Option<String>,
 34}
 35
 36impl PullRequestReview {
 37    pub fn with_body(self, body: impl ToString) -> Self {
 38        Self {
 39            body: Some(body.to_string()),
 40            ..self
 41        }
 42    }
 43}
 44
/// A comment posted on a pull request's conversation.
#[derive(Debug, Clone)]
pub struct PullRequestComment {
    /// The comment's author.
    pub user: GitHubUser,
    /// The comment text, if any.
    pub body: Option<String>,
}
 49}
 50
/// A GitHub login name.
///
/// Deserializes from an object of the form `{ "login": "..." }` and
/// dereferences to the inner `String`.
#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
pub struct GithubLogin {
    login: String,
}
 55
 56impl GithubLogin {
 57    pub fn new(login: String) -> Self {
 58        Self { login }
 59    }
 60}
 61
 62impl fmt::Display for GithubLogin {
 63    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
 64        write!(formatter, "@{}", self.login)
 65    }
 66}
 67
/// A commit author as returned by the commit-authors GraphQL query.
///
/// Equality is implemented by hand (not derived); see the `PartialEq` impl.
#[derive(Debug, Deserialize, Clone)]
pub struct CommitAuthor {
    /// The author name recorded in the commit.
    name: String,
    /// The author email recorded in the commit.
    email: String,
    /// The GitHub account associated with this author, if GitHub reported one.
    user: Option<GithubLogin>,
}
 74
 75impl CommitAuthor {
 76    pub(crate) fn user(&self) -> Option<&GithubLogin> {
 77        self.user.as_ref()
 78    }
 79}
 80
 81impl PartialEq for CommitAuthor {
 82    fn eq(&self, other: &Self) -> bool {
 83        self.user.as_ref().zip(other.user.as_ref()).map_or_else(
 84            || self.email == other.email || self.name == other.name,
 85            |(l, r)| l == r,
 86        )
 87    }
 88}
 89
 90impl fmt::Display for CommitAuthor {
 91    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
 92        match self.user.as_ref() {
 93            Some(user) => write!(formatter, "{} ({user})", self.name),
 94            None => write!(formatter, "{} ({})", self.name, self.email),
 95        }
 96    }
 97}
 98
/// The authors of a single commit.
///
/// Deserialized from an object with an `author` field (the primary author)
/// and an `authors` field that must already be a plain list of authors.
#[derive(Debug, Deserialize)]
pub struct CommitAuthors {
    /// The commit's primary (git) author.
    #[serde(rename = "author")]
    primary_author: CommitAuthor,
    /// Every author listed for the commit. This may include the primary
    /// author, which `CommitAuthors::co_authors` filters out.
    #[serde(rename = "authors")]
    co_authors: Vec<CommitAuthor>,
}
106
107impl CommitAuthors {
108    pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
109        self.co_authors.is_empty().not().then(|| {
110            self.co_authors
111                .iter()
112                .filter(|co_author| *co_author != &self.primary_author)
113        })
114    }
115}
116
/// Authors for a batch of commits, keyed by commit SHA.
///
/// Dereferences to the underlying `HashMap`.
#[derive(Debug, Deserialize, Deref)]
pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
119
120#[async_trait::async_trait(?Send)]
121pub trait GitHubApiClient {
122    async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData>;
123    async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>>;
124    async fn get_pull_request_comments(&self, pr_number: u64) -> Result<Vec<PullRequestComment>>;
125    async fn get_commit_authors(&self, commit_shas: &[&CommitSha]) -> Result<AuthorsForCommits>;
126    async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool>;
127    async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool>;
128    async fn actor_has_repository_write_permission(
129        &self,
130        login: &GithubLogin,
131    ) -> anyhow::Result<bool> {
132        Ok(self.check_org_membership(login).await?
133            || self.check_repo_write_permission(login).await?)
134    }
135    async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()>;
136}
137
/// Handle to a shared `GitHubApiClient` implementation.
///
/// Dereferences to the inner `Rc<dyn GitHubApiClient>`, so trait methods can
/// be called directly on a `GitHubClient`.
#[derive(Deref)]
pub struct GitHubClient {
    api: Rc<dyn GitHubApiClient>,
}
142
143impl GitHubClient {
144    pub fn new(api: Rc<dyn GitHubApiClient>) -> Self {
145        Self { api }
146    }
147
148    #[cfg(feature = "octo-client")]
149    pub async fn for_app(app_id: u64, app_private_key: &str) -> Result<Self> {
150        let client = OctocrabClient::new(app_id, app_private_key).await?;
151        Ok(Self::new(Rc::new(client)))
152    }
153}
154
155#[cfg(feature = "octo-client")]
156mod octo_client {
157    use anyhow::{Context, Result};
158    use futures::TryStreamExt as _;
159    use itertools::Itertools;
160    use jsonwebtoken::EncodingKey;
161    use octocrab::{
162        Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
163        service::middleware::cache::mem::InMemoryCache,
164    };
165    use serde::de::DeserializeOwned;
166    use tokio::pin;
167
168    use crate::git::CommitSha;
169
170    use super::{
171        AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment,
172        PullRequestData, PullRequestReview, ReviewState,
173    };
174
    /// Page size requested from paginated REST endpoints (GitHub's maximum is 100).
    const PAGE_SIZE: u8 = 100;
    /// Account that owns the target repository; also used to select the app installation.
    const ORG: &str = "zed-industries";
    /// Repository (owned by `ORG`) that every request targets.
    const REPO: &str = "zed";
178
    /// `GitHubApiClient` implementation backed by an `Octocrab` client
    /// authenticated as the GitHub App's installation on `ORG`.
    pub struct OctocrabClient {
        client: Octocrab,
    }
182
183    impl OctocrabClient {
184        pub async fn new(app_id: u64, app_private_key: &str) -> Result<Self> {
185            let octocrab = Octocrab::builder()
186                .cache(InMemoryCache::new())
187                .app(
188                    app_id.into(),
189                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
190                )
191                .build()?;
192
193            let installations = octocrab
194                .apps()
195                .installations()
196                .send()
197                .await
198                .context("Failed to fetch installations")?
199                .take_items();
200
201            let installation_id = installations
202                .into_iter()
203                .find(|installation| installation.account.login == ORG)
204                .context("Could not find Zed repository in installations")?
205                .id;
206
207            let client = octocrab.installation(installation_id)?;
208            Ok(Self { client })
209        }
210
211        fn build_co_authors_query<'a>(shas: impl IntoIterator<Item = &'a CommitSha>) -> String {
212            const FRAGMENT: &str = r#"
213                ... on Commit {
214                    author {
215                        name
216                        email
217                        user { login }
218                    }
219                    authors(first: 10) {
220                        nodes {
221                            name
222                            email
223                            user { login }
224                        }
225                    }
226                }
227            "#;
228
229            let objects: String = shas
230                .into_iter()
231                .map(|commit_sha| {
232                    format!(
233                        "commit{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
234                        sha = **commit_sha
235                    )
236                })
237                .join("\n");
238
239            format!("{{  repository(owner: \"{ORG}\", name: \"{REPO}\") {{ {objects}  }} }}")
240                .replace("\n", "")
241        }
242
243        async fn graphql<R: octocrab::FromResponse>(
244            &self,
245            query: &serde_json::Value,
246        ) -> octocrab::Result<R> {
247            self.client.graphql(query).await
248        }
249
250        async fn get_all<T: DeserializeOwned + 'static>(
251            &self,
252            page: Page<T>,
253        ) -> octocrab::Result<Vec<T>> {
254            self.get_filtered(page, |_| true).await
255        }
256
257        async fn get_filtered<T: DeserializeOwned + 'static>(
258            &self,
259            page: Page<T>,
260            predicate: impl Fn(&T) -> bool,
261        ) -> octocrab::Result<Vec<T>> {
262            let stream = page.into_stream(&self.client);
263            pin!(stream);
264
265            let mut results = Vec::new();
266
267            while let Some(item) = stream.try_next().await?
268                && predicate(&item)
269            {
270                results.push(item);
271            }
272
273            Ok(results)
274        }
275    }
276
277    #[async_trait::async_trait(?Send)]
278    impl GitHubApiClient for OctocrabClient {
279        async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
280            let pr = self.client.pulls(ORG, REPO).get(pr_number).await?;
281            Ok(PullRequestData {
282                number: pr.number,
283                user: pr.user.map(|user| GitHubUser { login: user.login }),
284                merged_by: pr.merged_by.map(|user| GitHubUser { login: user.login }),
285            })
286        }
287
288        async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
289            let page = self
290                .client
291                .pulls(ORG, REPO)
292                .list_reviews(pr_number)
293                .per_page(PAGE_SIZE)
294                .send()
295                .await?;
296
297            let reviews = self.get_all(page).await?;
298
299            Ok(reviews
300                .into_iter()
301                .map(|review| PullRequestReview {
302                    user: review.user.map(|user| GitHubUser { login: user.login }),
303                    state: review.state.map(|state| match state {
304                        OctocrabReviewState::Approved => ReviewState::Approved,
305                        _ => ReviewState::Other,
306                    }),
307                    body: review.body,
308                })
309                .collect())
310        }
311
312        async fn get_pull_request_comments(
313            &self,
314            pr_number: u64,
315        ) -> Result<Vec<PullRequestComment>> {
316            let page = self
317                .client
318                .issues(ORG, REPO)
319                .list_comments(pr_number)
320                .per_page(PAGE_SIZE)
321                .send()
322                .await?;
323
324            let comments = self.get_all(page).await?;
325
326            Ok(comments
327                .into_iter()
328                .map(|comment| PullRequestComment {
329                    user: GitHubUser {
330                        login: comment.user.login,
331                    },
332                    body: comment.body,
333                })
334                .collect())
335        }
336
337        async fn get_commit_authors(
338            &self,
339            commit_shas: &[&CommitSha],
340        ) -> Result<AuthorsForCommits> {
341            let query = Self::build_co_authors_query(commit_shas.iter().copied());
342            let query = serde_json::json!({ "query": query });
343            let mut response = self.graphql::<serde_json::Value>(&query).await?;
344
345            response
346                .get_mut("data")
347                .and_then(|data| data.get_mut("repository"))
348                .and_then(|repo| repo.as_object_mut())
349                .ok_or_else(|| anyhow::anyhow!("Unexpected response format!"))
350                .and_then(|commit_data| {
351                    let mut response_map = serde_json::Map::with_capacity(commit_data.len());
352
353                    for (key, value) in commit_data.iter_mut() {
354                        let key_without_prefix = key.strip_prefix("commit").unwrap_or(key);
355                        if let Some(authors) = value.get_mut("authors") {
356                            if let Some(nodes) = authors.get("nodes") {
357                                *authors = nodes.clone();
358                            }
359                        }
360
361                        response_map.insert(key_without_prefix.to_owned(), value.clone());
362                    }
363
364                    serde_json::from_value(serde_json::Value::Object(response_map))
365                        .context("Failed to deserialize commit authors")
366                })
367        }
368
369        async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
370            let page = self
371                .client
372                .orgs(ORG)
373                .list_members()
374                .per_page(PAGE_SIZE)
375                .send()
376                .await?;
377
378            let members = self.get_all(page).await?;
379
380            Ok(members
381                .into_iter()
382                .any(|member| member.login == login.as_str()))
383        }
384
385        async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool> {
386            // TODO: octocrab fails to deserialize the permission response and
387            // does not adhere to the scheme laid out at
388            // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user
389
390            #[derive(serde::Deserialize)]
391            #[serde(rename_all = "lowercase")]
392            enum RepoPermission {
393                Admin,
394                Write,
395                Read,
396                #[serde(other)]
397                Other,
398            }
399
400            #[derive(serde::Deserialize)]
401            struct RepositoryPermissions {
402                permission: RepoPermission,
403            }
404
405            self.client
406                .get::<RepositoryPermissions, _, _>(
407                    format!(
408                        "/repos/{ORG}/{REPO}/collaborators/{user}/permission",
409                        user = login.as_str()
410                    ),
411                    None::<&()>,
412                )
413                .await
414                .map(|response| {
415                    matches!(
416                        response.permission,
417                        RepoPermission::Write | RepoPermission::Admin
418                    )
419                })
420                .map_err(Into::into)
421        }
422
423        async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()> {
424            if self
425                .get_filtered(
426                    self.client
427                        .issues(ORG, REPO)
428                        .list_labels_for_issue(pr_number)
429                        .per_page(PAGE_SIZE)
430                        .send()
431                        .await?,
432                    |pr_label| pr_label.name == label,
433                )
434                .await
435                .is_ok_and(|l| l.is_empty())
436            {
437                self.client
438                    .issues(ORG, REPO)
439                    .add_labels(pr_number, &[label.to_owned()])
440                    .await?;
441            }
442
443            Ok(())
444        }
445    }
446}
447
448#[cfg(feature = "octo-client")]
449pub use octo_client::OctocrabClient;