github.rs

  1use std::{collections::HashMap, fmt, ops::Not, rc::Rc};
  2
  3use anyhow::Result;
  4use derive_more::Deref;
  5use serde::Deserialize;
  6
  7use crate::git::CommitSha;
  8
  9pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
 10
 11#[derive(Debug, Clone)]
 12pub struct GitHubUser {
 13    pub login: String,
 14}
 15
 16#[derive(Debug, Clone)]
 17pub struct PullRequestData {
 18    pub number: u64,
 19    pub user: Option<GitHubUser>,
 20    pub merged_by: Option<GitHubUser>,
 21}
 22
 23#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 24pub enum ReviewState {
 25    Approved,
 26    Other,
 27}
 28
 29#[derive(Debug, Clone)]
 30pub struct PullRequestReview {
 31    pub user: Option<GitHubUser>,
 32    pub state: Option<ReviewState>,
 33}
 34
 35#[derive(Debug, Clone)]
 36pub struct PullRequestComment {
 37    pub user: GitHubUser,
 38    pub body: Option<String>,
 39}
 40
 41#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
 42pub struct GithubLogin {
 43    login: String,
 44}
 45
 46impl GithubLogin {
 47    pub(crate) fn new(login: String) -> Self {
 48        Self { login }
 49    }
 50}
 51
 52impl fmt::Display for GithubLogin {
 53    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
 54        write!(formatter, "@{}", self.login)
 55    }
 56}
 57
 58#[derive(Debug, Deserialize, Clone)]
 59pub struct CommitAuthor {
 60    name: String,
 61    email: String,
 62    user: Option<GithubLogin>,
 63}
 64
 65impl CommitAuthor {
 66    pub(crate) fn user(&self) -> Option<&GithubLogin> {
 67        self.user.as_ref()
 68    }
 69}
 70
 71impl PartialEq for CommitAuthor {
 72    fn eq(&self, other: &Self) -> bool {
 73        self.user.as_ref().zip(other.user.as_ref()).map_or_else(
 74            || self.email == other.email || self.name == other.name,
 75            |(l, r)| l == r,
 76        )
 77    }
 78}
 79
 80impl fmt::Display for CommitAuthor {
 81    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
 82        match self.user.as_ref() {
 83            Some(user) => write!(formatter, "{} ({user})", self.name),
 84            None => write!(formatter, "{} ({})", self.name, self.email),
 85        }
 86    }
 87}
 88
 89#[derive(Debug, Deserialize)]
 90pub struct CommitAuthors {
 91    #[serde(rename = "author")]
 92    primary_author: CommitAuthor,
 93    #[serde(rename = "authors")]
 94    co_authors: Vec<CommitAuthor>,
 95}
 96
 97impl CommitAuthors {
 98    pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
 99        self.co_authors.is_empty().not().then(|| {
100            self.co_authors
101                .iter()
102                .filter(|co_author| *co_author != &self.primary_author)
103        })
104    }
105}
106
107#[derive(Debug, Deserialize, Deref)]
108pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
109
110#[async_trait::async_trait(?Send)]
111pub trait GitHubApiClient {
112    async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData>;
113    async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>>;
114    async fn get_pull_request_comments(&self, pr_number: u64) -> Result<Vec<PullRequestComment>>;
115    async fn get_commit_authors(&self, commit_shas: &[&CommitSha]) -> Result<AuthorsForCommits>;
116    async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool>;
117    async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()>;
118}
119
120pub struct GitHubClient {
121    api: Rc<dyn GitHubApiClient>,
122}
123
124impl GitHubClient {
125    pub fn new(api: Rc<dyn GitHubApiClient>) -> Self {
126        Self { api }
127    }
128
129    #[cfg(feature = "octo-client")]
130    pub async fn for_app(app_id: u64, app_private_key: &str) -> Result<Self> {
131        let client = OctocrabClient::new(app_id, app_private_key).await?;
132        Ok(Self::new(Rc::new(client)))
133    }
134
135    pub async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
136        self.api.get_pull_request(pr_number).await
137    }
138
139    pub async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
140        self.api.get_pull_request_reviews(pr_number).await
141    }
142
143    pub async fn get_pull_request_comments(
144        &self,
145        pr_number: u64,
146    ) -> Result<Vec<PullRequestComment>> {
147        self.api.get_pull_request_comments(pr_number).await
148    }
149
150    pub async fn get_commit_authors<'a>(
151        &self,
152        commit_shas: impl IntoIterator<Item = &'a CommitSha>,
153    ) -> Result<AuthorsForCommits> {
154        let shas: Vec<&CommitSha> = commit_shas.into_iter().collect();
155        self.api.get_commit_authors(&shas).await
156    }
157
158    pub async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
159        self.api.check_org_membership(login).await
160    }
161
162    pub async fn add_label_to_pull_request(&self, label: &str, pr_number: u64) -> Result<()> {
163        self.api
164            .ensure_pull_request_has_label(label, pr_number)
165            .await
166    }
167}
168
169#[cfg(feature = "octo-client")]
170mod octo_client {
171    use anyhow::{Context, Result};
172    use futures::TryStreamExt as _;
173    use itertools::Itertools;
174    use jsonwebtoken::EncodingKey;
175    use octocrab::{
176        Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
177        service::middleware::cache::mem::InMemoryCache,
178    };
179    use serde::de::DeserializeOwned;
180    use tokio::pin;
181
182    use crate::git::CommitSha;
183
184    use super::{
185        AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment,
186        PullRequestData, PullRequestReview, ReviewState,
187    };
188
189    const PAGE_SIZE: u8 = 100;
190    const ORG: &str = "zed-industries";
191    const REPO: &str = "zed";
192
193    pub struct OctocrabClient {
194        client: Octocrab,
195    }
196
197    impl OctocrabClient {
198        pub async fn new(app_id: u64, app_private_key: &str) -> Result<Self> {
199            let octocrab = Octocrab::builder()
200                .cache(InMemoryCache::new())
201                .app(
202                    app_id.into(),
203                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
204                )
205                .build()?;
206
207            let installations = octocrab
208                .apps()
209                .installations()
210                .send()
211                .await
212                .context("Failed to fetch installations")?
213                .take_items();
214
215            let installation_id = installations
216                .into_iter()
217                .find(|installation| installation.account.login == ORG)
218                .context("Could not find Zed repository in installations")?
219                .id;
220
221            let client = octocrab.installation(installation_id)?;
222            Ok(Self { client })
223        }
224
225        fn build_co_authors_query<'a>(shas: impl IntoIterator<Item = &'a CommitSha>) -> String {
226            const FRAGMENT: &str = r#"
227                ... on Commit {
228                    author {
229                        name
230                        email
231                        user { login }
232                    }
233                    authors(first: 10) {
234                        nodes {
235                            name
236                            email
237                            user { login }
238                        }
239                    }
240                }
241            "#;
242
243            let objects: String = shas
244                .into_iter()
245                .map(|commit_sha| {
246                    format!(
247                        "commit{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
248                        sha = **commit_sha
249                    )
250                })
251                .join("\n");
252
253            format!("{{  repository(owner: \"{ORG}\", name: \"{REPO}\") {{ {objects}  }} }}")
254                .replace("\n", "")
255        }
256
257        async fn graphql<R: octocrab::FromResponse>(
258            &self,
259            query: &serde_json::Value,
260        ) -> octocrab::Result<R> {
261            self.client.graphql(query).await
262        }
263
264        async fn get_all<T: DeserializeOwned + 'static>(
265            &self,
266            page: Page<T>,
267        ) -> octocrab::Result<Vec<T>> {
268            self.get_filtered(page, |_| true).await
269        }
270
271        async fn get_filtered<T: DeserializeOwned + 'static>(
272            &self,
273            page: Page<T>,
274            predicate: impl Fn(&T) -> bool,
275        ) -> octocrab::Result<Vec<T>> {
276            let stream = page.into_stream(&self.client);
277            pin!(stream);
278
279            let mut results = Vec::new();
280
281            while let Some(item) = stream.try_next().await?
282                && predicate(&item)
283            {
284                results.push(item);
285            }
286
287            Ok(results)
288        }
289    }
290
291    #[async_trait::async_trait(?Send)]
292    impl GitHubApiClient for OctocrabClient {
293        async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
294            let pr = self.client.pulls(ORG, REPO).get(pr_number).await?;
295            Ok(PullRequestData {
296                number: pr.number,
297                user: pr.user.map(|user| GitHubUser { login: user.login }),
298                merged_by: pr.merged_by.map(|user| GitHubUser { login: user.login }),
299            })
300        }
301
302        async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
303            let page = self
304                .client
305                .pulls(ORG, REPO)
306                .list_reviews(pr_number)
307                .per_page(PAGE_SIZE)
308                .send()
309                .await?;
310
311            let reviews = self.get_all(page).await?;
312
313            Ok(reviews
314                .into_iter()
315                .map(|review| PullRequestReview {
316                    user: review.user.map(|user| GitHubUser { login: user.login }),
317                    state: review.state.map(|state| match state {
318                        OctocrabReviewState::Approved => ReviewState::Approved,
319                        _ => ReviewState::Other,
320                    }),
321                })
322                .collect())
323        }
324
325        async fn get_pull_request_comments(
326            &self,
327            pr_number: u64,
328        ) -> Result<Vec<PullRequestComment>> {
329            let page = self
330                .client
331                .issues(ORG, REPO)
332                .list_comments(pr_number)
333                .per_page(PAGE_SIZE)
334                .send()
335                .await?;
336
337            let comments = self.get_all(page).await?;
338
339            Ok(comments
340                .into_iter()
341                .map(|comment| PullRequestComment {
342                    user: GitHubUser {
343                        login: comment.user.login,
344                    },
345                    body: comment.body,
346                })
347                .collect())
348        }
349
350        async fn get_commit_authors(
351            &self,
352            commit_shas: &[&CommitSha],
353        ) -> Result<AuthorsForCommits> {
354            let query = Self::build_co_authors_query(commit_shas.iter().copied());
355            let query = serde_json::json!({ "query": query });
356            let mut response = self.graphql::<serde_json::Value>(&query).await?;
357
358            response
359                .get_mut("data")
360                .and_then(|data| data.get_mut("repository"))
361                .and_then(|repo| repo.as_object_mut())
362                .ok_or_else(|| anyhow::anyhow!("Unexpected response format!"))
363                .and_then(|commit_data| {
364                    let mut response_map = serde_json::Map::with_capacity(commit_data.len());
365
366                    for (key, value) in commit_data.iter_mut() {
367                        let key_without_prefix = key.strip_prefix("commit").unwrap_or(key);
368                        if let Some(authors) = value.get_mut("authors") {
369                            if let Some(nodes) = authors.get("nodes") {
370                                *authors = nodes.clone();
371                            }
372                        }
373
374                        response_map.insert(key_without_prefix.to_owned(), value.clone());
375                    }
376
377                    serde_json::from_value(serde_json::Value::Object(response_map))
378                        .context("Failed to deserialize commit authors")
379                })
380        }
381
382        async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
383            let page = self
384                .client
385                .orgs(ORG)
386                .list_members()
387                .per_page(PAGE_SIZE)
388                .send()
389                .await?;
390
391            let members = self.get_all(page).await?;
392
393            Ok(members
394                .into_iter()
395                .any(|member| member.login == login.as_str()))
396        }
397
398        async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()> {
399            if self
400                .get_filtered(
401                    self.client
402                        .issues(ORG, REPO)
403                        .list_labels_for_issue(pr_number)
404                        .per_page(PAGE_SIZE)
405                        .send()
406                        .await?,
407                    |pr_label| pr_label.name == label,
408                )
409                .await
410                .is_ok_and(|l| l.is_empty())
411            {
412                self.client
413                    .issues(ORG, REPO)
414                    .add_labels(pr_number, &[label.to_owned()])
415                    .await?;
416            }
417
418            Ok(())
419        }
420    }
421}
422
423#[cfg(feature = "octo-client")]
424pub use octo_client::OctocrabClient;