1use std::{borrow::Cow, collections::HashMap, fmt, ops::Not, rc::Rc};
2
3use anyhow::Result;
4use derive_more::Deref;
5use serde::Deserialize;
6
7use crate::git::CommitSha;
8
/// Name of the GitHub label that marks a pull request as needing review.
pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
10
/// A GitHub account, as referenced from pull request, review and comment data.
#[derive(Clone, Debug)]
pub struct GithubUser {
    /// The account's login name.
    pub login: String,
}
15
16#[derive(Debug, Clone)]
17pub struct PullRequestData {
18 pub number: u64,
19 pub user: Option<GithubUser>,
20 pub merged_by: Option<GithubUser>,
21 pub labels: Option<Vec<String>>,
22}
23
/// Verdict of a pull request review, reduced to the only distinction this crate
/// needs: approved or not.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReviewState {
    /// The reviewer approved the pull request.
    Approved,
    /// Any review state other than an approval.
    Other,
}
29
30#[derive(Debug, Clone)]
31pub struct PullRequestReview {
32 pub user: Option<GithubUser>,
33 pub state: Option<ReviewState>,
34 pub body: Option<String>,
35}
36
37impl PullRequestReview {
38 pub fn with_body(self, body: impl ToString) -> Self {
39 Self {
40 body: Some(body.to_string()),
41 ..self
42 }
43 }
44}
45
46#[derive(Debug, Clone)]
47pub struct PullRequestComment {
48 pub user: GithubUser,
49 pub body: Option<String>,
50}
51
52#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
53pub struct GithubLogin {
54 login: String,
55}
56
57impl GithubLogin {
58 pub fn new(login: String) -> Self {
59 Self { login }
60 }
61}
62
63impl fmt::Display for GithubLogin {
64 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
65 write!(formatter, "@{}", self.login)
66 }
67}
68
69#[derive(Debug, Deserialize, Clone)]
70pub struct CommitAuthor {
71 name: String,
72 email: String,
73 user: Option<GithubLogin>,
74}
75
76impl CommitAuthor {
77 pub(crate) fn user(&self) -> Option<&GithubLogin> {
78 self.user.as_ref()
79 }
80}
81
82impl PartialEq for CommitAuthor {
83 fn eq(&self, other: &Self) -> bool {
84 self.user.as_ref().zip(other.user.as_ref()).map_or_else(
85 || self.email == other.email || self.name == other.name,
86 |(l, r)| l == r,
87 )
88 }
89}
90
91impl fmt::Display for CommitAuthor {
92 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
93 match self.user.as_ref() {
94 Some(user) => write!(formatter, "{} ({user})", self.name),
95 None => write!(formatter, "{} ({})", self.name, self.email),
96 }
97 }
98}
99
100#[derive(Debug, Deserialize)]
101pub struct CommitAuthors {
102 #[serde(rename = "author")]
103 primary_author: CommitAuthor,
104 #[serde(rename = "authors", deserialize_with = "graph_ql::deserialize_nodes")]
105 co_authors: Vec<CommitAuthor>,
106}
107
108impl CommitAuthors {
109 pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
110 self.co_authors.is_empty().not().then(|| {
111 self.co_authors
112 .iter()
113 .filter(|co_author| *co_author != &self.primary_author)
114 })
115 }
116}
117
118#[derive(Debug, Deref)]
119pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
120
121impl AuthorsForCommits {
122 const SHA_PREFIX: &'static str = "commit";
123}
124
125impl<'de> serde::Deserialize<'de> for AuthorsForCommits {
126 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
127 where
128 D: serde::Deserializer<'de>,
129 {
130 let raw = HashMap::<String, CommitAuthors>::deserialize(deserializer)?;
131 let map = raw
132 .into_iter()
133 .map(|(key, value)| {
134 let sha = key
135 .strip_prefix(AuthorsForCommits::SHA_PREFIX)
136 .unwrap_or(&key);
137 (CommitSha::new(sha.to_owned()), value)
138 })
139 .collect();
140 Ok(Self(map))
141 }
142}
143
/// A GitHub repository, identified by owner login and repository name.
#[derive(Clone)]
pub struct Repository<'a> {
    owner: Cow<'a, str>,
    name: Cow<'a, str>,
}

impl<'a> Repository<'a> {
    /// The `zed-industries/zed` repository.
    pub const ZED: Repository<'static> = Repository::new_static("zed-industries", "zed");

    /// Creates a repository handle that borrows `owner` and `name`.
    pub fn new(owner: &'a str, name: &'a str) -> Self {
        let owner = Cow::from(owner);
        let name = Cow::from(name);
        Repository { owner, name }
    }

    /// Login of the repository's owner.
    pub fn owner(&self) -> &str {
        self.owner.as_ref()
    }

    /// Repository name, without the owner.
    pub fn name(&self) -> &str {
        self.name.as_ref()
    }
}

impl Repository<'static> {
    /// `const` constructor for repositories named by `'static` strings, usable in
    /// constant definitions such as [`Repository::ZED`].
    pub const fn new_static(owner: &'static str, name: &'static str) -> Self {
        Repository {
            owner: Cow::Borrowed(owner),
            name: Cow::Borrowed(name),
        }
    }
}
177
/// The GitHub API operations this crate relies on.
///
/// `OctocrabClient` implements it when the `octo-client` feature is enabled, and
/// `GithubClient` holds it as an `Rc<dyn GithubApiClient>`. Declared with
/// `async_trait(?Send)`, so the returned futures are not required to be `Send`.
#[async_trait::async_trait(?Send)]
pub trait GithubApiClient {
    /// Fetches metadata for pull request `pr_number` in `repo`.
    async fn get_pull_request(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<PullRequestData>;
    /// Fetches the reviews submitted on pull request `pr_number` in `repo`.
    async fn get_pull_request_reviews(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<Vec<PullRequestReview>>;
    /// Fetches the comments posted on pull request `pr_number` in `repo`.
    async fn get_pull_request_comments(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<Vec<PullRequestComment>>;
    /// Fetches the primary author and the listed authors of each commit in
    /// `commit_shas`, keyed by SHA.
    async fn get_commit_authors(
        &self,
        repo: &Repository<'_>,
        commit_shas: &[&CommitSha],
    ) -> Result<AuthorsForCommits>;
    /// Returns whether `login` has write access to `repo`.
    ///
    /// NOTE(review): the Octocrab implementation also answers `true` for any member of
    /// the organization owning `repo`; confirm all implementations should agree.
    async fn check_repo_write_permission(
        &self,
        repo: &Repository<'_>,
        login: &GithubLogin,
    ) -> Result<bool>;
    /// Adds `label` to issue number `issue_number` in `repo`.
    async fn add_label_to_issue(
        &self,
        repo: &Repository<'_>,
        label: &str,
        issue_number: u64,
    ) -> Result<()>;
}
212
213#[derive(Deref)]
214pub struct GithubClient {
215 api: Rc<dyn GithubApiClient>,
216}
217
218impl GithubClient {
219 pub fn new(api: Rc<dyn GithubApiClient>) -> Self {
220 Self { api }
221 }
222
223 #[cfg(feature = "octo-client")]
224 pub async fn for_app_in_repo(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
225 let client = OctocrabClient::new(app_id, app_private_key, org).await?;
226 Ok(Self::new(Rc::new(client)))
227 }
228}
229
230pub mod graph_ql {
231 use anyhow::{Context as _, Result};
232 use itertools::Itertools as _;
233 use serde::Deserialize;
234
235 use crate::git::CommitSha;
236
237 use super::AuthorsForCommits;
238
239 #[derive(Debug, Deserialize)]
240 pub struct GraphQLResponse<T> {
241 pub data: Option<T>,
242 pub errors: Option<Vec<GraphQLError>>,
243 }
244
245 impl<T> GraphQLResponse<T> {
246 pub fn into_data(self) -> Result<T> {
247 if let Some(errors) = &self.errors {
248 if !errors.is_empty() {
249 let messages: String = errors.iter().map(|e| e.message.as_str()).join("; ");
250 anyhow::bail!("GraphQL error: {messages}");
251 }
252 }
253
254 self.data.context("GraphQL response contained no data")
255 }
256 }
257
258 #[derive(Debug, Deserialize)]
259 pub struct GraphQLError {
260 pub message: String,
261 }
262
263 #[derive(Debug, Deserialize)]
264 pub struct CommitAuthorsResponse {
265 pub repository: AuthorsForCommits,
266 }
267
268 pub fn deserialize_nodes<'de, T, D>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
269 where
270 T: Deserialize<'de>,
271 D: serde::Deserializer<'de>,
272 {
273 #[derive(Deserialize)]
274 struct Nodes<T> {
275 nodes: Vec<T>,
276 }
277 Nodes::<T>::deserialize(deserializer).map(|wrapper| wrapper.nodes)
278 }
279
280 pub fn build_co_authors_query<'a>(
281 org: &str,
282 repo: &str,
283 shas: impl IntoIterator<Item = &'a CommitSha>,
284 ) -> String {
285 const FRAGMENT: &str = r#"
286 ... on Commit {
287 author {
288 name
289 email
290 user { login }
291 }
292 authors(first: 10) {
293 nodes {
294 name
295 email
296 user { login }
297 }
298 }
299 }
300 "#;
301
302 let objects = shas
303 .into_iter()
304 .map(|commit_sha| {
305 format!(
306 "{sha_prefix}{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
307 sha_prefix = AuthorsForCommits::SHA_PREFIX,
308 sha = **commit_sha,
309 )
310 })
311 .join("\n");
312
313 format!("{{ repository(owner: \"{org}\", name: \"{repo}\") {{ {objects} }} }}")
314 .replace("\n", "")
315 }
316}
317
#[cfg(feature = "octo-client")]
mod octo_client {
    use anyhow::{Context, Result};
    use futures::TryStreamExt as _;
    use jsonwebtoken::EncodingKey;
    use octocrab::{
        Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
        service::middleware::cache::mem::InMemoryCache,
    };
    use serde::de::DeserializeOwned;
    use tokio::pin;

    use crate::{
        git::CommitSha,
        github::{Repository, graph_ql},
    };

    use super::{
        AuthorsForCommits, GithubApiClient, GithubLogin, GithubUser, PullRequestComment,
        PullRequestData, PullRequestReview, ReviewState,
    };

    /// `per_page` value requested from paginated REST endpoints.
    const PAGE_SIZE: u8 = 100;

    /// [`GithubApiClient`] backed by `octocrab` and authenticated as a GitHub App
    /// installation.
    pub struct OctocrabClient {
        client: Octocrab,
    }

    impl OctocrabClient {
        /// Creates a client that acts as GitHub App `app_id`'s installation on the
        /// `org` account.
        ///
        /// Authenticates the app with the RSA PEM key `app_private_key`, then looks up
        /// the installation whose account login equals `org`.
        ///
        /// # Errors
        ///
        /// Fails if:
        /// - the key cannot be parsed or the client cannot be built;
        /// - the installations cannot be fetched;
        /// - no fetched installation belongs to `org`.
        pub async fn new(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
            let octocrab = Octocrab::builder()
                .cache(InMemoryCache::new())
                .app(
                    app_id.into(),
                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
                )
                .build()?;

            // NOTE(review): only the first page of installations is inspected; an app
            // installed on more accounts than fit on one page could miss `org`.
            let installations = octocrab
                .apps()
                .installations()
                .send()
                .await
                .context("Failed to fetch installations")?
                .take_items();

            let installation_id = installations
                .into_iter()
                .find(|installation| installation.account.login == org)
                .with_context(|| {
                    format!("Could not find an installation for organization `{org}`")
                })?
                .id;

            let client = octocrab.installation(installation_id)?;
            Ok(Self { client })
        }

        /// Sends the GraphQL request body `query` and unwraps the response envelope
        /// into `R`, failing on GraphQL-level errors or missing data.
        async fn graphql<R: DeserializeOwned>(&self, query: &serde_json::Value) -> Result<R> {
            let response: serde_json::Value = self.client.graphql(query).await?;
            let parsed: graph_ql::GraphQLResponse<R> = serde_json::from_value(response)
                .context("Failed to parse GraphQL response envelope")?;
            parsed.into_data()
        }

        /// Collects every item yielded by `page`'s stream, across subsequent pages.
        async fn get_all<T: DeserializeOwned + 'static>(
            &self,
            page: Page<T>,
        ) -> octocrab::Result<Vec<T>> {
            self.get_filtered(page, |_| true).await
        }

        /// Collects the items yielded by `page`'s stream for which `predicate` returns
        /// `true`.
        ///
        /// Non-matching items are skipped and streaming continues. The previous
        /// implementation stopped at the first non-matching item, which behaved like
        /// `take_while` rather than a filter.
        async fn get_filtered<T: DeserializeOwned + 'static>(
            &self,
            page: Page<T>,
            predicate: impl Fn(&T) -> bool,
        ) -> octocrab::Result<Vec<T>> {
            let stream = page.into_stream(&self.client);
            pin!(stream);

            let mut results = Vec::new();

            while let Some(item) = stream.try_next().await? {
                if predicate(&item) {
                    results.push(item);
                }
            }

            Ok(results)
        }
    }

    #[async_trait::async_trait(?Send)]
    impl GithubApiClient for OctocrabClient {
        async fn get_pull_request(
            &self,
            repo: &Repository<'_>,
            pr_number: u64,
        ) -> Result<PullRequestData> {
            let pr = self
                .client
                .pulls(repo.owner.as_ref(), repo.name.as_ref())
                .get(pr_number)
                .await?;
            Ok(PullRequestData {
                number: pr.number,
                user: pr.user.map(|user| GithubUser { login: user.login }),
                merged_by: pr.merged_by.map(|user| GithubUser { login: user.login }),
                labels: pr
                    .labels
                    .map(|labels| labels.into_iter().map(|label| label.name).collect()),
            })
        }

        async fn get_pull_request_reviews(
            &self,
            repo: &Repository<'_>,
            pr_number: u64,
        ) -> Result<Vec<PullRequestReview>> {
            let page = self
                .client
                .pulls(repo.owner.as_ref(), repo.name.as_ref())
                .list_reviews(pr_number)
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let reviews = self.get_all(page).await?;

            Ok(reviews
                .into_iter()
                .map(|review| PullRequestReview {
                    user: review.user.map(|user| GithubUser { login: user.login }),
                    // Only approvals matter to this crate; every other state collapses.
                    state: review.state.map(|state| match state {
                        OctocrabReviewState::Approved => ReviewState::Approved,
                        _ => ReviewState::Other,
                    }),
                    body: review.body,
                })
                .collect())
        }

        async fn get_pull_request_comments(
            &self,
            repo: &Repository<'_>,
            pr_number: u64,
        ) -> Result<Vec<PullRequestComment>> {
            let page = self
                .client
                .issues(repo.owner.as_ref(), repo.name.as_ref())
                .list_comments(pr_number)
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let comments = self.get_all(page).await?;

            Ok(comments
                .into_iter()
                .map(|comment| PullRequestComment {
                    user: GithubUser {
                        login: comment.user.login,
                    },
                    body: comment.body,
                })
                .collect())
        }

        async fn get_commit_authors(
            &self,
            repo: &Repository<'_>,
            commit_shas: &[&CommitSha],
        ) -> Result<AuthorsForCommits> {
            // With no commits the generated query would contain an empty selection set,
            // which is invalid GraphQL, so answer locally without a request.
            if commit_shas.is_empty() {
                return Ok(AuthorsForCommits(Default::default()));
            }

            let query = graph_ql::build_co_authors_query(
                repo.owner.as_ref(),
                repo.name.as_ref(),
                commit_shas.iter().copied(),
            );
            let query = serde_json::json!({ "query": query });
            self.graphql::<graph_ql::CommitAuthorsResponse>(&query)
                .await
                .map(|response| response.repository)
        }

        async fn check_repo_write_permission(
            &self,
            repo: &Repository<'_>,
            login: &GithubLogin,
        ) -> Result<bool> {
            // Check org membership first - we save ourselves a few requests that way
            let page = self
                .client
                .orgs(repo.owner.as_ref())
                .list_members()
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let members = self.get_all(page).await?;

            if members
                .into_iter()
                .any(|member| member.login == login.as_str())
            {
                return Ok(true);
            }

            // TODO: octocrab fails to deserialize the permission response and
            // does not adhere to the scheme laid out at
            // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user

            #[derive(serde::Deserialize)]
            #[serde(rename_all = "lowercase")]
            enum RepoPermission {
                Admin,
                Write,
                Read,
                #[serde(other)]
                Other,
            }

            #[derive(serde::Deserialize)]
            struct RepositoryPermissions {
                permission: RepoPermission,
            }

            self.client
                .get::<RepositoryPermissions, _, _>(
                    format!(
                        "/repos/{owner}/{repo}/collaborators/{user}/permission",
                        owner = repo.owner.as_ref(),
                        repo = repo.name.as_ref(),
                        user = login.as_str()
                    ),
                    None::<&()>,
                )
                .await
                .map(|response| {
                    matches!(
                        response.permission,
                        RepoPermission::Write | RepoPermission::Admin
                    )
                })
                .map_err(Into::into)
        }

        async fn add_label_to_issue(
            &self,
            repo: &Repository<'_>,
            label: &str,
            issue_number: u64,
        ) -> Result<()> {
            self.client
                .issues(repo.owner.as_ref(), repo.name.as_ref())
                .add_labels(issue_number, &[label.to_owned()])
                .await
                .map(|_| ())
                .map_err(Into::into)
        }
    }
}
577
578#[cfg(feature = "octo-client")]
579pub use octo_client::OctocrabClient;