1use std::{borrow::Cow, collections::HashMap, fmt};
2
3use anyhow::Result;
4use derive_more::Deref;
5use serde::Deserialize;
6
7use crate::git::CommitSha;
8
/// Name of the GitHub label that marks a pull request as needing review.
pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
10
/// A GitHub account referenced by pull-request data (author, merger, ...).
#[derive(Debug, Clone)]
pub struct GithubUser {
    /// The account's login name.
    pub login: String,
}

/// The subset of a pull request's fields this crate works with.
#[derive(Debug, Clone)]
pub struct PullRequestData {
    /// The pull request number within its repository.
    pub number: u64,
    /// The pull request's author, when GitHub reports one.
    pub user: Option<GithubUser>,
    /// The account that merged the pull request, when GitHub reports one.
    pub merged_by: Option<GithubUser>,
    /// Names of the labels attached to the pull request, when GitHub reports them.
    pub labels: Option<Vec<String>>,
}
23
/// Coarse state of a pull-request review: only approvals are distinguished.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReviewState {
    Approved,
    Other,
}

/// How the author of a review or comment is related to the repository.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthorAssociation {
    Owner,
    Member,
    Collaborator,
    Contributor,
    FirstTimeContributor,
    FirstTimer,
    Mannequin,
    None,
}

impl AuthorAssociation {
    /// Returns `true` for the associations treated as having write access:
    /// owners, members and collaborators.
    pub fn has_write_access(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Owner | Self::Member | Self::Collaborator => true,
            Self::Contributor
            | Self::FirstTimeContributor
            | Self::FirstTimer
            | Self::Mannequin
            | Self::None => false,
        }
    }
}
47
48pub trait Approvable {
49 fn author_login(&self) -> Option<&str>;
50 fn review_state(&self) -> Option<ReviewState>;
51 fn body(&self) -> Option<&str>;
52 fn author_association(&self) -> Option<AuthorAssociation>;
53}
54
55#[derive(Debug, Clone)]
56pub struct PullRequestReview {
57 pub user: Option<GithubUser>,
58 pub state: Option<ReviewState>,
59 pub body: Option<String>,
60 pub author_association: Option<AuthorAssociation>,
61}
62
63impl PullRequestReview {
64 pub fn with_body(self, body: impl ToString) -> Self {
65 Self {
66 body: Some(body.to_string()),
67 ..self
68 }
69 }
70}
71
72impl Approvable for PullRequestReview {
73 fn author_login(&self) -> Option<&str> {
74 self.user.as_ref().map(|user| user.login.as_str())
75 }
76
77 fn review_state(&self) -> Option<ReviewState> {
78 self.state
79 }
80
81 fn body(&self) -> Option<&str> {
82 self.body.as_deref()
83 }
84
85 fn author_association(&self) -> Option<AuthorAssociation> {
86 self.author_association
87 }
88}
89
90#[derive(Debug, Clone)]
91pub struct PullRequestComment {
92 pub user: GithubUser,
93 pub body: Option<String>,
94 pub author_association: Option<AuthorAssociation>,
95}
96
97impl Approvable for PullRequestComment {
98 fn author_login(&self) -> Option<&str> {
99 Some(&self.user.login)
100 }
101
102 fn review_state(&self) -> Option<ReviewState> {
103 None
104 }
105
106 fn body(&self) -> Option<&str> {
107 self.body.as_deref()
108 }
109
110 fn author_association(&self) -> Option<AuthorAssociation> {
111 self.author_association
112 }
113}
114
115#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
116pub struct GithubLogin {
117 login: String,
118}
119
120impl GithubLogin {
121 pub fn new(login: String) -> Self {
122 Self { login }
123 }
124}
125
126impl fmt::Display for GithubLogin {
127 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
128 write!(formatter, "@{}", self.login)
129 }
130}
131
132#[derive(Debug, Deserialize, Clone)]
133pub struct CommitAuthor {
134 name: String,
135 email: String,
136 user: Option<GithubLogin>,
137}
138
139impl CommitAuthor {
140 pub(crate) fn user(&self) -> Option<&GithubLogin> {
141 self.user.as_ref()
142 }
143}
144
145impl PartialEq for CommitAuthor {
146 fn eq(&self, other: &Self) -> bool {
147 self.user.as_ref().zip(other.user.as_ref()).map_or_else(
148 || self.email == other.email || self.name == other.name,
149 |(l, r)| l == r,
150 )
151 }
152}
153
154impl fmt::Display for CommitAuthor {
155 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
156 match self.user.as_ref() {
157 Some(user) => write!(formatter, "{} ({user})", self.name),
158 None => write!(formatter, "{} ({})", self.name, self.email),
159 }
160 }
161}
162
163#[derive(Debug, Deserialize, Clone)]
164pub struct CommitSignature {
165 #[serde(rename = "isValid")]
166 is_valid: bool,
167 signer: Option<GithubLogin>,
168}
169
170impl CommitSignature {
171 pub fn is_valid(&self) -> bool {
172 self.is_valid
173 }
174
175 pub fn signer(&self) -> Option<&GithubLogin> {
176 self.signer.as_ref()
177 }
178}
179
180#[derive(Debug, Clone, Deserialize)]
181pub struct CommitFileChange {
182 pub filename: String,
183}
184
185#[derive(Debug, Deserialize)]
186pub struct CommitMetadata {
187 #[serde(rename = "author")]
188 primary_author: CommitAuthor,
189 #[serde(rename = "authors", deserialize_with = "graph_ql::deserialize_nodes")]
190 co_authors: Vec<CommitAuthor>,
191 #[serde(default)]
192 signature: Option<CommitSignature>,
193 #[serde(default)]
194 additions: u64,
195 #[serde(default)]
196 deletions: u64,
197}
198
199impl CommitMetadata {
200 pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
201 let mut co_authors = self
202 .co_authors
203 .iter()
204 .filter(|co_author| *co_author != &self.primary_author)
205 .peekable();
206
207 co_authors.peek().is_some().then_some(co_authors)
208 }
209
210 pub fn primary_author(&self) -> &CommitAuthor {
211 &self.primary_author
212 }
213
214 pub fn signature(&self) -> Option<&CommitSignature> {
215 self.signature.as_ref()
216 }
217
218 pub fn additions(&self) -> u64 {
219 self.additions
220 }
221
222 pub fn deletions(&self) -> u64 {
223 self.deletions
224 }
225}
226
227#[derive(Debug, Deref)]
228pub struct CommitMetadataBySha(HashMap<CommitSha, CommitMetadata>);
229
230impl CommitMetadataBySha {
231 const SHA_PREFIX: &'static str = "commit";
232}
233
234impl<'de> serde::Deserialize<'de> for CommitMetadataBySha {
235 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
236 where
237 D: serde::Deserializer<'de>,
238 {
239 let raw = HashMap::<String, CommitMetadata>::deserialize(deserializer)?;
240 let map = raw
241 .into_iter()
242 .map(|(key, value)| {
243 let sha = key
244 .strip_prefix(CommitMetadataBySha::SHA_PREFIX)
245 .unwrap_or(&key);
246 (CommitSha::new(sha.to_owned()), value)
247 })
248 .collect();
249 Ok(Self(map))
250 }
251}
252
/// A GitHub repository identified by its owner and name.
#[derive(Clone)]
pub struct Repository<'a> {
    owner: Cow<'a, str>,
    name: Cow<'a, str>,
}

impl<'a> Repository<'a> {
    /// The `zed-industries/zed` repository.
    pub const ZED: Repository<'static> = Repository::new_static("zed-industries", "zed");

    /// Creates a handle that borrows `owner` and `name`.
    pub fn new(owner: &'a str, name: &'a str) -> Self {
        let owner: Cow<'a, str> = owner.into();
        let name: Cow<'a, str> = name.into();
        Self { owner, name }
    }

    /// The owning user or organization login.
    pub fn owner(&self) -> &str {
        self.owner.as_ref()
    }

    /// The repository name.
    pub fn name(&self) -> &str {
        self.name.as_ref()
    }
}

impl Repository<'static> {
    /// `const` constructor for repositories known at compile time.
    pub const fn new_static(owner: &'static str, name: &'static str) -> Self {
        Repository {
            owner: Cow::Borrowed(owner),
            name: Cow::Borrowed(name),
        }
    }
}
286
287#[async_trait::async_trait(?Send)]
288pub trait GithubApiClient {
289 async fn get_pull_request(
290 &self,
291 repo: &Repository<'_>,
292 pr_number: u64,
293 ) -> Result<PullRequestData>;
294 async fn get_pull_request_reviews(
295 &self,
296 repo: &Repository<'_>,
297 pr_number: u64,
298 ) -> Result<Vec<PullRequestReview>>;
299 async fn get_pull_request_comments(
300 &self,
301 repo: &Repository<'_>,
302 pr_number: u64,
303 ) -> Result<Vec<PullRequestComment>>;
304 async fn get_commit_metadata(
305 &self,
306 repo: &Repository<'_>,
307 commit_shas: &[&CommitSha],
308 ) -> Result<CommitMetadataBySha>;
309 async fn get_commit_files(
310 &self,
311 repo: &Repository<'_>,
312 sha: &CommitSha,
313 ) -> Result<Vec<CommitFileChange>>;
314 async fn check_repo_write_permission(
315 &self,
316 repo: &Repository<'_>,
317 login: &GithubLogin,
318 ) -> Result<bool>;
319 async fn add_label_to_issue(
320 &self,
321 repo: &Repository<'_>,
322 label: &str,
323 issue_number: u64,
324 ) -> Result<()>;
325}
326
327pub mod graph_ql {
328 use anyhow::{Context as _, Result};
329 use itertools::Itertools as _;
330 use serde::Deserialize;
331
332 use crate::git::CommitSha;
333
334 use super::CommitMetadataBySha;
335
336 #[derive(Debug, Deserialize)]
337 pub struct GraphQLResponse<T> {
338 pub data: Option<T>,
339 pub errors: Option<Vec<GraphQLError>>,
340 }
341
342 impl<T> GraphQLResponse<T> {
343 pub fn into_data(self) -> Result<T> {
344 if let Some(errors) = &self.errors {
345 if !errors.is_empty() {
346 let messages: String = errors.iter().map(|e| e.message.as_str()).join("; ");
347 anyhow::bail!("GraphQL error: {messages}");
348 }
349 }
350
351 self.data.context("GraphQL response contained no data")
352 }
353 }
354
355 #[derive(Debug, Deserialize)]
356 pub struct GraphQLError {
357 pub message: String,
358 }
359
360 #[derive(Debug, Deserialize)]
361 pub struct CommitMetadataResponse {
362 pub repository: CommitMetadataBySha,
363 }
364
365 pub fn deserialize_nodes<'de, T, D>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
366 where
367 T: Deserialize<'de>,
368 D: serde::Deserializer<'de>,
369 {
370 #[derive(Deserialize)]
371 struct Nodes<T> {
372 nodes: Vec<T>,
373 }
374 Nodes::<T>::deserialize(deserializer).map(|wrapper| wrapper.nodes)
375 }
376
377 pub fn build_commit_metadata_query<'a>(
378 org: &str,
379 repo: &str,
380 shas: impl IntoIterator<Item = &'a CommitSha>,
381 ) -> String {
382 const FRAGMENT: &str = r#"
383 ... on Commit {
384 author {
385 name
386 email
387 user { login }
388 }
389 authors(first: 10) {
390 nodes {
391 name
392 email
393 user { login }
394 }
395 }
396 signature {
397 isValid
398 signer { login }
399 }
400 additions
401 deletions
402 }
403 "#;
404
405 let objects = shas
406 .into_iter()
407 .map(|commit_sha| {
408 format!(
409 "{sha_prefix}{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
410 sha_prefix = CommitMetadataBySha::SHA_PREFIX,
411 sha = **commit_sha,
412 )
413 })
414 .join("\n");
415
416 format!("{{ repository(owner: \"{org}\", name: \"{repo}\") {{ {objects} }} }}")
417 .replace("\n", "")
418 }
419}
420
#[cfg(feature = "octo-client")]
mod octo_client {
    //! [`GithubApiClient`] implementation backed by `octocrab`, authenticated as a
    //! GitHub App installation.

    use anyhow::{Context, Result};
    use futures::TryStreamExt as _;
    use jsonwebtoken::EncodingKey;
    use octocrab::{
        Octocrab, Page,
        models::{
            AuthorAssociation as OctocrabAuthorAssociation,
            pulls::ReviewState as OctocrabReviewState,
        },
        service::middleware::cache::mem::InMemoryCache,
    };
    use serde::de::DeserializeOwned;
    use tokio::pin;

    use crate::{
        git::CommitSha,
        github::{Repository, graph_ql},
    };

    use super::{
        AuthorAssociation, CommitFileChange, CommitMetadataBySha, GithubApiClient, GithubLogin,
        GithubUser, PullRequestComment, PullRequestData, PullRequestReview, ReviewState,
    };

    /// Maps octocrab's author association onto this crate's [`AuthorAssociation`].
    ///
    /// Any variant not listed explicitly (octocrab's enum presumably has more, given the
    /// catch-all arm) maps to [`AuthorAssociation::None`].
    fn convert_author_association(association: OctocrabAuthorAssociation) -> AuthorAssociation {
        match association {
            OctocrabAuthorAssociation::Owner => AuthorAssociation::Owner,
            OctocrabAuthorAssociation::Member => AuthorAssociation::Member,
            OctocrabAuthorAssociation::Collaborator => AuthorAssociation::Collaborator,
            OctocrabAuthorAssociation::Contributor => AuthorAssociation::Contributor,
            OctocrabAuthorAssociation::FirstTimeContributor => {
                AuthorAssociation::FirstTimeContributor
            }
            OctocrabAuthorAssociation::FirstTimer => AuthorAssociation::FirstTimer,
            OctocrabAuthorAssociation::Mannequin => AuthorAssociation::Mannequin,
            OctocrabAuthorAssociation::None => AuthorAssociation::None,
            _ => AuthorAssociation::None,
        }
    }

    /// Page size requested from paginated REST endpoints.
    const PAGE_SIZE: u8 = 100;

    /// GitHub client authenticated as a GitHub App installation.
    pub struct OctocrabClient {
        client: Octocrab,
    }

    impl OctocrabClient {
        /// Authenticates as GitHub App `app_id` and returns a client scoped to the app's
        /// installation on the `org` account.
        ///
        /// # Errors
        /// Fails if the private key is not valid RSA PEM, if the installations cannot be
        /// fetched, or if no installation belongs to `org`.
        pub async fn new(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
            let octocrab = Octocrab::builder()
                .cache(InMemoryCache::new())
                .app(
                    app_id.into(),
                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
                )
                .build()?;

            // NOTE(review): only the first page of installations is inspected; fine as long
            // as the app is installed on few accounts — confirm.
            let installations = octocrab
                .apps()
                .installations()
                .send()
                .await
                .context("Failed to fetch installations")?
                .take_items();

            let installation_id = installations
                .into_iter()
                .find(|installation| installation.account.login == org)
                .with_context(|| {
                    format!("Could not find a GitHub App installation for organization `{org}`")
                })?
                .id;

            let client = octocrab.installation(installation_id)?;
            Ok(Self { client })
        }

        /// Posts the JSON request body `query` to the GraphQL endpoint and decodes the
        /// envelope's `data` into `R`.
        async fn graphql<R: DeserializeOwned>(&self, query: &serde_json::Value) -> Result<R> {
            let response: serde_json::Value = self.client.graphql(query).await?;
            let parsed: graph_ql::GraphQLResponse<R> = serde_json::from_value(response)
                .context("Failed to parse GraphQL response envelope")?;
            parsed.into_data()
        }

        /// Collects every item of `page` and all following pages.
        async fn get_all<T: DeserializeOwned + 'static>(
            &self,
            page: Page<T>,
        ) -> octocrab::Result<Vec<T>> {
            self.get_filtered(page, |_| true).await
        }

        /// Walks `page` and all following pages, keeping the items matching `predicate`.
        ///
        /// Every page is fetched; items failing `predicate` are skipped rather than ending
        /// the walk, so matches on later pages are not lost.
        async fn get_filtered<T: DeserializeOwned + 'static>(
            &self,
            page: Page<T>,
            predicate: impl Fn(&T) -> bool,
        ) -> octocrab::Result<Vec<T>> {
            let stream = page.into_stream(&self.client);
            pin!(stream);

            let mut results = Vec::new();
            while let Some(item) = stream.try_next().await? {
                if predicate(&item) {
                    results.push(item);
                }
            }

            Ok(results)
        }
    }

    #[async_trait::async_trait(?Send)]
    impl GithubApiClient for OctocrabClient {
        async fn get_pull_request(
            &self,
            repo: &Repository<'_>,
            pr_number: u64,
        ) -> Result<PullRequestData> {
            let pr = self
                .client
                .pulls(repo.owner(), repo.name())
                .get(pr_number)
                .await?;
            Ok(PullRequestData {
                number: pr.number,
                user: pr.user.map(|user| GithubUser { login: user.login }),
                merged_by: pr.merged_by.map(|user| GithubUser { login: user.login }),
                labels: pr
                    .labels
                    .map(|labels| labels.into_iter().map(|label| label.name).collect()),
            })
        }

        async fn get_pull_request_reviews(
            &self,
            repo: &Repository<'_>,
            pr_number: u64,
        ) -> Result<Vec<PullRequestReview>> {
            let page = self
                .client
                .pulls(repo.owner(), repo.name())
                .list_reviews(pr_number)
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let reviews = self.get_all(page).await?;

            Ok(reviews
                .into_iter()
                .map(|review| PullRequestReview {
                    user: review.user.map(|user| GithubUser { login: user.login }),
                    state: review.state.map(|state| match state {
                        OctocrabReviewState::Approved => ReviewState::Approved,
                        _ => ReviewState::Other,
                    }),
                    body: review.body,
                    author_association: review.author_association.map(convert_author_association),
                })
                .collect())
        }

        async fn get_pull_request_comments(
            &self,
            repo: &Repository<'_>,
            pr_number: u64,
        ) -> Result<Vec<PullRequestComment>> {
            // Pull request conversation comments live on the issues API.
            let page = self
                .client
                .issues(repo.owner(), repo.name())
                .list_comments(pr_number)
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let comments = self.get_all(page).await?;

            Ok(comments
                .into_iter()
                .map(|comment| PullRequestComment {
                    user: GithubUser {
                        login: comment.user.login,
                    },
                    body: comment.body,
                    author_association: comment.author_association.map(convert_author_association),
                })
                .collect())
        }

        async fn get_commit_metadata(
            &self,
            repo: &Repository<'_>,
            commit_shas: &[&CommitSha],
        ) -> Result<CommitMetadataBySha> {
            if commit_shas.is_empty() {
                // `repository { }` (an empty selection set) is a GraphQL syntax error,
                // and there is nothing to ask for anyway.
                return Ok(CommitMetadataBySha(Default::default()));
            }

            let query = graph_ql::build_commit_metadata_query(
                repo.owner(),
                repo.name(),
                commit_shas.iter().copied(),
            );
            let query = serde_json::json!({ "query": query });
            self.graphql::<graph_ql::CommitMetadataResponse>(&query)
                .await
                .map(|response| response.repository)
        }

        async fn get_commit_files(
            &self,
            repo: &Repository<'_>,
            sha: &CommitSha,
        ) -> Result<Vec<CommitFileChange>> {
            // NOTE(review): GitHub's "get a commit" endpoint paginates `files` for very
            // large commits; only the first response is read here.
            let response = self
                .client
                .commits(repo.owner(), repo.name())
                .get(sha.as_str())
                .await?;

            Ok(response
                .files
                .into_iter()
                .flatten()
                .map(|file| CommitFileChange {
                    filename: file.filename,
                })
                .collect())
        }

        async fn check_repo_write_permission(
            &self,
            repo: &Repository<'_>,
            login: &GithubLogin,
        ) -> Result<bool> {
            // Check org membership first - we save ourselves a few requests that way.
            let page = self
                .client
                .orgs(repo.owner())
                .list_members()
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let members = self.get_all(page).await?;

            // GitHub logins are case-insensitive, so compare them that way.
            if members
                .iter()
                .any(|member| member.login.eq_ignore_ascii_case(login.as_str()))
            {
                return Ok(true);
            }

            // TODO: octocrab fails to deserialize the permission response and
            // does not adhere to the scheme laid out at
            // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user

            #[derive(serde::Deserialize)]
            #[serde(rename_all = "lowercase")]
            enum RepoPermission {
                Admin,
                Write,
                Read,
                #[serde(other)]
                Other,
            }

            #[derive(serde::Deserialize)]
            struct RepositoryPermissions {
                permission: RepoPermission,
            }

            self.client
                .get::<RepositoryPermissions, _, _>(
                    format!(
                        "/repos/{owner}/{repo}/collaborators/{user}/permission",
                        owner = repo.owner(),
                        repo = repo.name(),
                        user = login.as_str()
                    ),
                    None::<&()>,
                )
                .await
                .map(|response| {
                    matches!(
                        response.permission,
                        RepoPermission::Write | RepoPermission::Admin
                    )
                })
                .map_err(Into::into)
        }

        async fn add_label_to_issue(
            &self,
            repo: &Repository<'_>,
            label: &str,
            issue_number: u64,
        ) -> Result<()> {
            self.client
                .issues(repo.owner(), repo.name())
                .add_labels(issue_number, &[label.to_owned()])
                .await
                .map(|_| ())
                .map_err(Into::into)
        }
    }
}
723
724#[cfg(feature = "octo-client")]
725pub use octo_client::OctocrabClient;