1use std::{borrow::Cow, collections::HashMap, fmt};
2
3use anyhow::Result;
4use derive_more::Deref;
5use serde::Deserialize;
6
7use crate::git::CommitSha;
8
/// Text of the GitHub label `PR state:needs review`.
///
/// NOTE(review): presumably applied to pull requests via
/// `GithubApiClient::add_label_to_issue`; confirm at the call sites.
pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
10
/// A GitHub account, reduced to its login name.
#[derive(Debug, Clone)]
pub struct GithubUser {
    /// The account's login (username), without a leading `@`.
    pub login: String,
}
15
/// The subset of a pull request's data that this crate reads.
#[derive(Debug, Clone)]
pub struct PullRequestData {
    /// Pull request number within its repository.
    pub number: u64,
    /// Account reported as the PR's `user` (presumably its author), if any.
    pub user: Option<GithubUser>,
    /// Account reported as having merged the PR; `None` when absent from the response.
    pub merged_by: Option<GithubUser>,
    /// Names of the labels on the PR; `None` when the response carried no label list.
    pub labels: Option<Vec<String>>,
}
23
/// Simplified state of a pull request review.
///
/// Only approval is distinguished; every other review state is folded into
/// [`ReviewState::Other`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReviewState {
    /// The reviewer approved the pull request.
    Approved,
    /// Any non-approving review state.
    Other,
}
29
30#[derive(Debug, Clone)]
31pub struct PullRequestReview {
32 pub user: Option<GithubUser>,
33 pub state: Option<ReviewState>,
34 pub body: Option<String>,
35}
36
37impl PullRequestReview {
38 pub fn with_body(self, body: impl ToString) -> Self {
39 Self {
40 body: Some(body.to_string()),
41 ..self
42 }
43 }
44}
45
/// A comment posted on a pull request.
#[derive(Debug, Clone)]
pub struct PullRequestComment {
    /// Account that wrote the comment.
    pub user: GithubUser,
    /// Comment text, if any.
    pub body: Option<String>,
}
51
52#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
53pub struct GithubLogin {
54 login: String,
55}
56
57impl GithubLogin {
58 pub fn new(login: String) -> Self {
59 Self { login }
60 }
61}
62
63impl fmt::Display for GithubLogin {
64 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
65 write!(formatter, "@{}", self.login)
66 }
67}
68
69#[derive(Debug, Deserialize, Clone)]
70pub struct CommitAuthor {
71 name: String,
72 email: String,
73 user: Option<GithubLogin>,
74}
75
76impl CommitAuthor {
77 pub(crate) fn user(&self) -> Option<&GithubLogin> {
78 self.user.as_ref()
79 }
80}
81
82impl PartialEq for CommitAuthor {
83 fn eq(&self, other: &Self) -> bool {
84 self.user.as_ref().zip(other.user.as_ref()).map_or_else(
85 || self.email == other.email || self.name == other.name,
86 |(l, r)| l == r,
87 )
88 }
89}
90
91impl fmt::Display for CommitAuthor {
92 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
93 match self.user.as_ref() {
94 Some(user) => write!(formatter, "{} ({user})", self.name),
95 None => write!(formatter, "{} ({})", self.name, self.email),
96 }
97 }
98}
99
100#[derive(Debug, Deserialize, Clone)]
101pub struct CommitSignature {
102 #[serde(rename = "isValid")]
103 is_valid: bool,
104 signer: Option<GithubLogin>,
105}
106
107impl CommitSignature {
108 pub fn is_valid(&self) -> bool {
109 self.is_valid
110 }
111
112 pub fn signer(&self) -> Option<&GithubLogin> {
113 self.signer.as_ref()
114 }
115}
116
117#[derive(Debug, Deserialize)]
118pub struct CommitMetadata {
119 #[serde(rename = "author")]
120 primary_author: CommitAuthor,
121 #[serde(rename = "authors", deserialize_with = "graph_ql::deserialize_nodes")]
122 co_authors: Vec<CommitAuthor>,
123 #[serde(default)]
124 signature: Option<CommitSignature>,
125 #[serde(default)]
126 additions: u64,
127 #[serde(default)]
128 deletions: u64,
129}
130
131impl CommitMetadata {
132 pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
133 let mut co_authors = self
134 .co_authors
135 .iter()
136 .filter(|co_author| *co_author != &self.primary_author)
137 .peekable();
138
139 co_authors.peek().is_some().then_some(co_authors)
140 }
141
142 pub fn primary_author(&self) -> &CommitAuthor {
143 &self.primary_author
144 }
145
146 pub fn signature(&self) -> Option<&CommitSignature> {
147 self.signature.as_ref()
148 }
149
150 pub fn additions(&self) -> u64 {
151 self.additions
152 }
153
154 pub fn deletions(&self) -> u64 {
155 self.deletions
156 }
157}
158
159#[derive(Debug, Deref)]
160pub struct CommitMetadataBySha(HashMap<CommitSha, CommitMetadata>);
161
162impl CommitMetadataBySha {
163 const SHA_PREFIX: &'static str = "commit";
164}
165
166impl<'de> serde::Deserialize<'de> for CommitMetadataBySha {
167 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
168 where
169 D: serde::Deserializer<'de>,
170 {
171 let raw = HashMap::<String, CommitMetadata>::deserialize(deserializer)?;
172 let map = raw
173 .into_iter()
174 .map(|(key, value)| {
175 let sha = key
176 .strip_prefix(CommitMetadataBySha::SHA_PREFIX)
177 .unwrap_or(&key);
178 (CommitSha::new(sha.to_owned()), value)
179 })
180 .collect();
181 Ok(Self(map))
182 }
183}
184
/// A GitHub repository identified by its owner (user or organization) and name.
///
/// Both parts are stored as `Cow<str>`, so a `Repository<'static>` can be built
/// in `const` context (see [`Repository::ZED`]).
#[derive(Clone, Debug)]
pub struct Repository<'a> {
    owner: Cow<'a, str>,
    name: Cow<'a, str>,
}

impl<'a> Repository<'a> {
    /// The `zed-industries/zed` repository.
    pub const ZED: Repository<'static> = Repository::new_static("zed-industries", "zed");

    /// Creates a repository reference that borrows `owner` and `name`.
    pub fn new(owner: &'a str, name: &'a str) -> Self {
        Self {
            owner: Cow::Borrowed(owner),
            name: Cow::Borrowed(name),
        }
    }

    /// The repository owner (user or organization login).
    pub fn owner(&self) -> &str {
        &self.owner
    }

    /// The repository name.
    pub fn name(&self) -> &str {
        &self.name
    }
}

impl Repository<'static> {
    /// `const` constructor for repositories named by string literals.
    pub const fn new_static(owner: &'static str, name: &'static str) -> Self {
        Self {
            owner: Cow::Borrowed(owner),
            name: Cow::Borrowed(name),
        }
    }
}
218
/// The GitHub operations this crate needs, abstracted so they can be backed by
/// different clients.
///
/// Declared with `async_trait(?Send)`, so the returned futures are not required
/// to be `Send`.
#[async_trait::async_trait(?Send)]
pub trait GithubApiClient {
    /// Fetches pull request `pr_number` in `repo`.
    async fn get_pull_request(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<PullRequestData>;
    /// Fetches the reviews submitted on pull request `pr_number` in `repo`.
    async fn get_pull_request_reviews(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<Vec<PullRequestReview>>;
    /// Fetches the comments on pull request `pr_number` in `repo`.
    async fn get_pull_request_comments(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<Vec<PullRequestComment>>;
    /// Fetches metadata for each commit in `commit_shas`, keyed by SHA.
    async fn get_commit_metadata(
        &self,
        repo: &Repository<'_>,
        commit_shas: &[&CommitSha],
    ) -> Result<CommitMetadataBySha>;
    /// Returns `true` when `login` is considered to have write access to `repo`.
    async fn check_repo_write_permission(
        &self,
        repo: &Repository<'_>,
        login: &GithubLogin,
    ) -> Result<bool>;
    /// Adds `label` to issue (or pull request) number `issue_number` in `repo`.
    async fn add_label_to_issue(
        &self,
        repo: &Repository<'_>,
        label: &str,
        issue_number: u64,
    ) -> Result<()>;
}
253
254pub mod graph_ql {
255 use anyhow::{Context as _, Result};
256 use itertools::Itertools as _;
257 use serde::Deserialize;
258
259 use crate::git::CommitSha;
260
261 use super::CommitMetadataBySha;
262
263 #[derive(Debug, Deserialize)]
264 pub struct GraphQLResponse<T> {
265 pub data: Option<T>,
266 pub errors: Option<Vec<GraphQLError>>,
267 }
268
269 impl<T> GraphQLResponse<T> {
270 pub fn into_data(self) -> Result<T> {
271 if let Some(errors) = &self.errors {
272 if !errors.is_empty() {
273 let messages: String = errors.iter().map(|e| e.message.as_str()).join("; ");
274 anyhow::bail!("GraphQL error: {messages}");
275 }
276 }
277
278 self.data.context("GraphQL response contained no data")
279 }
280 }
281
282 #[derive(Debug, Deserialize)]
283 pub struct GraphQLError {
284 pub message: String,
285 }
286
287 #[derive(Debug, Deserialize)]
288 pub struct CommitMetadataResponse {
289 pub repository: CommitMetadataBySha,
290 }
291
292 pub fn deserialize_nodes<'de, T, D>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
293 where
294 T: Deserialize<'de>,
295 D: serde::Deserializer<'de>,
296 {
297 #[derive(Deserialize)]
298 struct Nodes<T> {
299 nodes: Vec<T>,
300 }
301 Nodes::<T>::deserialize(deserializer).map(|wrapper| wrapper.nodes)
302 }
303
304 pub fn build_commit_metadata_query<'a>(
305 org: &str,
306 repo: &str,
307 shas: impl IntoIterator<Item = &'a CommitSha>,
308 ) -> String {
309 const FRAGMENT: &str = r#"
310 ... on Commit {
311 author {
312 name
313 email
314 user { login }
315 }
316 authors(first: 10) {
317 nodes {
318 name
319 email
320 user { login }
321 }
322 }
323 signature {
324 isValid
325 signer { login }
326 }
327 additions
328 deletions
329 }
330 "#;
331
332 let objects = shas
333 .into_iter()
334 .map(|commit_sha| {
335 format!(
336 "{sha_prefix}{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
337 sha_prefix = CommitMetadataBySha::SHA_PREFIX,
338 sha = **commit_sha,
339 )
340 })
341 .join("\n");
342
343 format!("{{ repository(owner: \"{org}\", name: \"{repo}\") {{ {objects} }} }}")
344 .replace("\n", "")
345 }
346}
347
#[cfg(feature = "octo-client")]
mod octo_client {
    use anyhow::{Context, Result};
    use futures::TryStreamExt as _;
    use jsonwebtoken::EncodingKey;
    use octocrab::{
        Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
        service::middleware::cache::mem::InMemoryCache,
    };
    use serde::de::DeserializeOwned;
    use tokio::pin;

    use crate::{
        git::CommitSha,
        github::{Repository, graph_ql},
    };

    use super::{
        CommitMetadataBySha, GithubApiClient, GithubLogin, GithubUser, PullRequestComment,
        PullRequestData, PullRequestReview, ReviewState,
    };

    /// Page size requested from paginated REST endpoints.
    const PAGE_SIZE: u8 = 100;

    /// [`GithubApiClient`] backed by an `octocrab` client that is authenticated
    /// as a GitHub App installation.
    pub struct OctocrabClient {
        client: Octocrab,
    }

    impl OctocrabClient {
        /// Authenticates as GitHub App `app_id` (signing with `app_private_key`,
        /// an RSA PEM). Then scopes the client to the app installation whose
        /// account login equals `org`.
        ///
        /// # Errors
        /// Fails when:
        /// - the key cannot be parsed,
        /// - the client cannot be built,
        /// - the installations cannot be fetched, or
        /// - no installation matches `org`.
        pub async fn new(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
            let octocrab = Octocrab::builder()
                .cache(InMemoryCache::new())
                .app(
                    app_id.into(),
                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
                )
                .build()?;

            // NOTE(review): only the items of the first page returned by the
            // installations endpoint are searched. Confirm the app never has more
            // installations than fit on one page.
            let installations = octocrab
                .apps()
                .installations()
                .send()
                .await
                .context("Failed to fetch installations")?
                .take_items();

            let installation_id = installations
                .into_iter()
                .find(|installation| installation.account.login == org)
                .with_context(|| {
                    format!("Could not find a GitHub App installation for `{org}`")
                })?
                .id;

            let client = octocrab.installation(installation_id)?;
            Ok(Self { client })
        }

        /// Sends a GraphQL request and decodes its `data` payload as `R`.
        ///
        /// GraphQL-level `errors` are turned into an error by
        /// `GraphQLResponse::into_data`.
        async fn graphql<R: DeserializeOwned>(&self, query: &serde_json::Value) -> Result<R> {
            let response: serde_json::Value = self.client.graphql(query).await?;
            let parsed: graph_ql::GraphQLResponse<R> = serde_json::from_value(response)
                .context("Failed to parse GraphQL response envelope")?;
            parsed.into_data()
        }

        /// Collects the items of `page` and of every page that follows it.
        async fn get_all<T: DeserializeOwned + 'static>(
            &self,
            page: Page<T>,
        ) -> octocrab::Result<Vec<T>> {
            self.get_filtered(page, |_| true).await
        }

        /// Walks `page` and all following pages, keeping only the items for which
        /// `predicate` returns `true`.
        ///
        /// Rejected items are skipped; iteration does not stop at them.
        async fn get_filtered<T: DeserializeOwned + 'static>(
            &self,
            page: Page<T>,
            predicate: impl Fn(&T) -> bool,
        ) -> octocrab::Result<Vec<T>> {
            let stream = page.into_stream(&self.client);
            pin!(stream);

            let mut results = Vec::new();

            while let Some(item) = stream.try_next().await? {
                if predicate(&item) {
                    results.push(item);
                }
            }

            Ok(results)
        }
    }

    #[async_trait::async_trait(?Send)]
    impl GithubApiClient for OctocrabClient {
        async fn get_pull_request(
            &self,
            repo: &Repository<'_>,
            pr_number: u64,
        ) -> Result<PullRequestData> {
            let pr = self
                .client
                .pulls(repo.owner.as_ref(), repo.name.as_ref())
                .get(pr_number)
                .await?;
            Ok(PullRequestData {
                number: pr.number,
                user: pr.user.map(|user| GithubUser { login: user.login }),
                merged_by: pr.merged_by.map(|user| GithubUser { login: user.login }),
                labels: pr
                    .labels
                    .map(|labels| labels.into_iter().map(|label| label.name).collect()),
            })
        }

        async fn get_pull_request_reviews(
            &self,
            repo: &Repository<'_>,
            pr_number: u64,
        ) -> Result<Vec<PullRequestReview>> {
            let page = self
                .client
                .pulls(repo.owner.as_ref(), repo.name.as_ref())
                .list_reviews(pr_number)
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let reviews = self.get_all(page).await?;

            Ok(reviews
                .into_iter()
                .map(|review| PullRequestReview {
                    user: review.user.map(|user| GithubUser { login: user.login }),
                    // Only approval is tracked; every other state becomes `Other`.
                    state: review.state.map(|state| match state {
                        OctocrabReviewState::Approved => ReviewState::Approved,
                        _ => ReviewState::Other,
                    }),
                    body: review.body,
                })
                .collect())
        }

        async fn get_pull_request_comments(
            &self,
            repo: &Repository<'_>,
            pr_number: u64,
        ) -> Result<Vec<PullRequestComment>> {
            // The comments are fetched through the issues endpoint, using the
            // PR number as the issue number.
            let page = self
                .client
                .issues(repo.owner.as_ref(), repo.name.as_ref())
                .list_comments(pr_number)
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let comments = self.get_all(page).await?;

            Ok(comments
                .into_iter()
                .map(|comment| PullRequestComment {
                    user: GithubUser {
                        login: comment.user.login,
                    },
                    body: comment.body,
                })
                .collect())
        }

        async fn get_commit_metadata(
            &self,
            repo: &Repository<'_>,
            commit_shas: &[&CommitSha],
        ) -> Result<CommitMetadataBySha> {
            let query = graph_ql::build_commit_metadata_query(
                repo.owner.as_ref(),
                repo.name.as_ref(),
                commit_shas.iter().copied(),
            );
            let query = serde_json::json!({ "query": query });
            self.graphql::<graph_ql::CommitMetadataResponse>(&query)
                .await
                .map(|response| response.repository)
        }

        async fn check_repo_write_permission(
            &self,
            repo: &Repository<'_>,
            login: &GithubLogin,
        ) -> Result<bool> {
            // Check org membership first; members of the repository owner's org
            // are treated as having write access without a per-repository lookup.
            let page = self
                .client
                .orgs(repo.owner.as_ref())
                .list_members()
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let members = self.get_all(page).await?;

            if members
                .into_iter()
                .any(|member| member.login == login.as_str())
            {
                return Ok(true);
            }

            // TODO: octocrab fails to deserialize the permission response, which
            // does not adhere to the schema laid out at
            // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user
            // so the endpoint is queried directly with a minimal local model.

            #[derive(serde::Deserialize)]
            #[serde(rename_all = "lowercase")]
            enum RepoPermission {
                Admin,
                Write,
                Read,
                // Any other permission string is collapsed into `Other`.
                #[serde(other)]
                Other,
            }

            #[derive(serde::Deserialize)]
            struct RepositoryPermissions {
                permission: RepoPermission,
            }

            self.client
                .get::<RepositoryPermissions, _, _>(
                    format!(
                        "/repos/{owner}/{repo}/collaborators/{user}/permission",
                        owner = repo.owner.as_ref(),
                        repo = repo.name.as_ref(),
                        user = login.as_str()
                    ),
                    None::<&()>,
                )
                .await
                .map(|response| {
                    matches!(
                        response.permission,
                        RepoPermission::Write | RepoPermission::Admin
                    )
                })
                .map_err(Into::into)
        }

        async fn add_label_to_issue(
            &self,
            repo: &Repository<'_>,
            label: &str,
            issue_number: u64,
        ) -> Result<()> {
            self.client
                .issues(repo.owner.as_ref(), repo.name.as_ref())
                .add_labels(issue_number, &[label.to_owned()])
                .await
                .map(|_| ())
                .map_err(Into::into)
        }
    }
}
607
608#[cfg(feature = "octo-client")]
609pub use octo_client::OctocrabClient;