1use std::{borrow::Cow, collections::HashMap, fmt, ops::Not, rc::Rc, sync::LazyLock};
2
3use anyhow::Result;
4use derive_more::Deref;
5use serde::Deserialize;
6
7use crate::git::CommitSha;
8
/// Label text `PR state:needs review`.
///
/// NOTE(review): presumably the label passed to
/// `GithubApiClient::add_label_to_issue` by callers outside this file — confirm.
pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
10
/// Minimal view of a GitHub account: only its login is kept.
#[derive(Debug, Clone)]
pub struct GithubUser {
    /// The account's login name.
    pub login: String,
}
15
/// The subset of pull request data exposed by [`GithubApiClient::get_pull_request`].
#[derive(Debug, Clone)]
pub struct PullRequestData {
    /// The pull request number.
    pub number: u64,
    /// The pull request's `user`, if present.
    /// NOTE(review): presumably the author of the pull request — confirm.
    pub user: Option<GithubUser>,
    /// The account that merged the pull request, if present.
    pub merged_by: Option<GithubUser>,
    /// Names of the labels on the pull request, if a label list is present.
    pub labels: Option<Vec<String>>,
}
23
/// Coarse review state: only approval is distinguished from everything else.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReviewState {
    /// The review approved the pull request.
    Approved,
    /// Any review state other than an approval.
    Other,
}
29
30#[derive(Debug, Clone)]
31pub struct PullRequestReview {
32 pub user: Option<GithubUser>,
33 pub state: Option<ReviewState>,
34 pub body: Option<String>,
35}
36
37impl PullRequestReview {
38 pub fn with_body(self, body: impl ToString) -> Self {
39 Self {
40 body: Some(body.to_string()),
41 ..self
42 }
43 }
44}
45
/// A single comment on a pull request.
#[derive(Debug, Clone)]
pub struct PullRequestComment {
    /// The account that wrote the comment.
    pub user: GithubUser,
    /// The comment text, if any.
    pub body: Option<String>,
}
51
52#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
53pub struct GithubLogin {
54 login: String,
55}
56
57impl GithubLogin {
58 pub fn new(login: String) -> Self {
59 Self { login }
60 }
61}
62
63impl fmt::Display for GithubLogin {
64 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
65 write!(formatter, "@{}", self.login)
66 }
67}
68
/// Author information attached to a commit.
///
/// Equality is implemented by hand further down in this file rather than
/// derived.
#[derive(Debug, Deserialize, Clone)]
pub struct CommitAuthor {
    /// The author's name.
    name: String,
    /// The author's e-mail address.
    email: String,
    /// The GitHub login linked to this author, if any.
    user: Option<GithubLogin>,
}
75
76pub(crate) static ZED_ZIPPY_AUTHOR: LazyLock<CommitAuthor> = LazyLock::new(|| CommitAuthor {
77 name: "Zed Zippy".to_string(),
78 email: "234243425+zed-zippy[bot]@users.noreply.github.com".to_string(),
79 user: Some(GithubLogin {
80 login: "zed-zippy[bot]".to_string(),
81 }),
82});
83
84impl CommitAuthor {
85 pub(crate) fn user(&self) -> Option<&GithubLogin> {
86 self.user.as_ref()
87 }
88}
89
90impl PartialEq for CommitAuthor {
91 fn eq(&self, other: &Self) -> bool {
92 self.user.as_ref().zip(other.user.as_ref()).map_or_else(
93 || self.email == other.email || self.name == other.name,
94 |(l, r)| l == r,
95 )
96 }
97}
98
99impl fmt::Display for CommitAuthor {
100 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
101 match self.user.as_ref() {
102 Some(user) => write!(formatter, "{} ({user})", self.name),
103 None => write!(formatter, "{} ({})", self.name, self.email),
104 }
105 }
106}
107
/// The authors recorded for a single commit.
#[derive(Debug, Deserialize)]
pub struct CommitAuthors {
    /// Deserialized from the `author` key.
    #[serde(rename = "author")]
    primary_author: CommitAuthor,
    /// Deserialized from the `nodes` list found under the `authors` key.
    #[serde(rename = "authors", deserialize_with = "graph_ql::deserialize_nodes")]
    co_authors: Vec<CommitAuthor>,
}
115
116impl CommitAuthors {
117 pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
118 self.co_authors.is_empty().not().then(|| {
119 self.co_authors
120 .iter()
121 .filter(|co_author| *co_author != &self.primary_author)
122 })
123 }
124}
125
/// Map from commit SHA to the authors recorded for that commit.
#[derive(Debug, Deref)]
pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);

impl AuthorsForCommits {
    /// Prefix placed in front of each SHA to form its alias in the query built
    /// by `graph_ql::build_co_authors_query`; removed again when the response
    /// is deserialized.
    const SHA_PREFIX: &'static str = "commit";
}
132
133impl<'de> serde::Deserialize<'de> for AuthorsForCommits {
134 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
135 where
136 D: serde::Deserializer<'de>,
137 {
138 let raw = HashMap::<String, CommitAuthors>::deserialize(deserializer)?;
139 let map = raw
140 .into_iter()
141 .map(|(key, value)| {
142 let sha = key
143 .strip_prefix(AuthorsForCommits::SHA_PREFIX)
144 .unwrap_or(&key);
145 (CommitSha::new(sha.to_owned()), value)
146 })
147 .collect();
148 Ok(Self(map))
149 }
150}
151
/// Identifies a GitHub repository by its owner and its name.
#[derive(Clone)]
pub struct Repository<'a> {
    owner: Cow<'a, str>,
    name: Cow<'a, str>,
}

impl<'a> Repository<'a> {
    /// The `zed-industries/zed` repository.
    pub const ZED: Repository<'static> = Repository::new_static("zed-industries", "zed");

    /// Creates a repository handle that borrows `owner` and `name`.
    pub fn new(owner: &'a str, name: &'a str) -> Self {
        Repository {
            owner: Cow::from(owner),
            name: Cow::from(name),
        }
    }

    /// The owner (user or organization) part of the repository.
    pub fn owner(&self) -> &str {
        self.owner.as_ref()
    }

    /// The name part of the repository.
    pub fn name(&self) -> &str {
        self.name.as_ref()
    }
}

impl Repository<'static> {
    /// `const` counterpart of [`Repository::new`] for `'static` strings.
    pub const fn new_static(owner: &'static str, name: &'static str) -> Self {
        Repository {
            owner: Cow::Borrowed(owner),
            name: Cow::Borrowed(name),
        }
    }
}
185
/// Abstraction over the GitHub operations used through [`GithubClient`].
///
/// Declared with `async_trait(?Send)`, so the futures returned by these
/// methods are not required to be `Send`.
#[async_trait::async_trait(?Send)]
pub trait GithubApiClient {
    /// Fetches pull request `pr_number` of `repo`.
    async fn get_pull_request(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<PullRequestData>;
    /// Fetches the reviews of pull request `pr_number` of `repo`.
    async fn get_pull_request_reviews(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<Vec<PullRequestReview>>;
    /// Fetches the comments of pull request `pr_number` of `repo`.
    async fn get_pull_request_comments(
        &self,
        repo: &Repository<'_>,
        pr_number: u64,
    ) -> Result<Vec<PullRequestComment>>;
    /// Fetches the author information of each commit in `commit_shas` of `repo`.
    async fn get_commit_authors(
        &self,
        repo: &Repository<'_>,
        commit_shas: &[&CommitSha],
    ) -> Result<AuthorsForCommits>;
    /// Reports whether `login` may write to `repo`.
    ///
    /// The octocrab-backed implementation in this file answers `true` for
    /// members of the organization owning `repo`, and otherwise for
    /// collaborators whose permission is `write` or `admin`.
    async fn check_repo_write_permission(
        &self,
        repo: &Repository<'_>,
        login: &GithubLogin,
    ) -> Result<bool>;
    /// Adds `label` to issue `issue_number` of `repo`.
    async fn add_label_to_issue(
        &self,
        repo: &Repository<'_>,
        label: &str,
        issue_number: u64,
    ) -> Result<()>;
}
220
221#[derive(Deref)]
222pub struct GithubClient {
223 api: Rc<dyn GithubApiClient>,
224}
225
226impl GithubClient {
227 pub fn new(api: Rc<dyn GithubApiClient>) -> Self {
228 Self { api }
229 }
230
231 #[cfg(feature = "octo-client")]
232 pub async fn for_app_in_repo(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
233 let client = OctocrabClient::new(app_id, app_private_key, org).await?;
234 Ok(Self::new(Rc::new(client)))
235 }
236}
237
238pub mod graph_ql {
239 use anyhow::{Context as _, Result};
240 use itertools::Itertools as _;
241 use serde::Deserialize;
242
243 use crate::git::CommitSha;
244
245 use super::AuthorsForCommits;
246
 /// Envelope of a GraphQL response: an optional payload plus an optional
 /// list of errors.
 #[derive(Debug, Deserialize)]
 pub struct GraphQLResponse<T> {
     /// The response payload, if present.
     pub data: Option<T>,
     /// The errors reported alongside the response, if present.
     pub errors: Option<Vec<GraphQLError>>,
 }
252
253 impl<T> GraphQLResponse<T> {
254 pub fn into_data(self) -> Result<T> {
255 if let Some(errors) = &self.errors {
256 if !errors.is_empty() {
257 let messages: String = errors.iter().map(|e| e.message.as_str()).join("; ");
258 anyhow::bail!("GraphQL error: {messages}");
259 }
260 }
261
262 self.data.context("GraphQL response contained no data")
263 }
264 }
265
 /// A single entry of the `errors` list of a GraphQL response; only the
 /// message is deserialized.
 #[derive(Debug, Deserialize)]
 pub struct GraphQLError {
     /// The error message.
     pub message: String,
 }
270
 /// Payload of the query produced by [`build_co_authors_query`].
 #[derive(Debug, Deserialize)]
 pub struct CommitAuthorsResponse {
     /// The per-commit author data found under the `repository` key.
     pub repository: AuthorsForCommits,
 }
275
276 pub fn deserialize_nodes<'de, T, D>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
277 where
278 T: Deserialize<'de>,
279 D: serde::Deserializer<'de>,
280 {
281 #[derive(Deserialize)]
282 struct Nodes<T> {
283 nodes: Vec<T>,
284 }
285 Nodes::<T>::deserialize(deserializer).map(|wrapper| wrapper.nodes)
286 }
287
288 pub fn build_co_authors_query<'a>(
289 org: &str,
290 repo: &str,
291 shas: impl IntoIterator<Item = &'a CommitSha>,
292 ) -> String {
293 const FRAGMENT: &str = r#"
294 ... on Commit {
295 author {
296 name
297 email
298 user { login }
299 }
300 authors(first: 10) {
301 nodes {
302 name
303 email
304 user { login }
305 }
306 }
307 }
308 "#;
309
310 let objects = shas
311 .into_iter()
312 .map(|commit_sha| {
313 format!(
314 "{sha_prefix}{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
315 sha_prefix = AuthorsForCommits::SHA_PREFIX,
316 sha = **commit_sha,
317 )
318 })
319 .join("\n");
320
321 format!("{{ repository(owner: \"{org}\", name: \"{repo}\") {{ {objects} }} }}")
322 .replace("\n", "")
323 }
324}
325
326#[cfg(feature = "octo-client")]
327mod octo_client {
328 use anyhow::{Context, Result};
329 use futures::TryStreamExt as _;
330 use jsonwebtoken::EncodingKey;
331 use octocrab::{
332 Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
333 service::middleware::cache::mem::InMemoryCache,
334 };
335 use serde::de::DeserializeOwned;
336 use tokio::pin;
337
338 use crate::{
339 git::CommitSha,
340 github::{Repository, graph_ql},
341 };
342
343 use super::{
344 AuthorsForCommits, GithubApiClient, GithubLogin, GithubUser, PullRequestComment,
345 PullRequestData, PullRequestReview, ReviewState,
346 };
347
 /// Page size passed to `per_page` by the paginated requests in this module.
 const PAGE_SIZE: u8 = 100;
349
 /// [`GithubApiClient`] implementation backed by an `Octocrab` instance.
 pub struct OctocrabClient {
     /// The installation-scoped client produced by [`OctocrabClient::new`].
     client: Octocrab,
 }
353
354 impl OctocrabClient {
355 pub async fn new(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
356 let octocrab = Octocrab::builder()
357 .cache(InMemoryCache::new())
358 .app(
359 app_id.into(),
360 EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
361 )
362 .build()?;
363
364 let installations = octocrab
365 .apps()
366 .installations()
367 .send()
368 .await
369 .context("Failed to fetch installations")?
370 .take_items();
371
372 let installation_id = installations
373 .into_iter()
374 .find(|installation| installation.account.login == org)
375 .context("Could not find Zed repository in installations")?
376 .id;
377
378 let client = octocrab.installation(installation_id)?;
379 Ok(Self { client })
380 }
381
382 async fn graphql<R: DeserializeOwned>(&self, query: &serde_json::Value) -> Result<R> {
383 let response: serde_json::Value = self.client.graphql(query).await?;
384 let parsed: graph_ql::GraphQLResponse<R> = serde_json::from_value(response)
385 .context("Failed to parse GraphQL response envelope")?;
386 parsed.into_data()
387 }
388
389 async fn get_all<T: DeserializeOwned + 'static>(
390 &self,
391 page: Page<T>,
392 ) -> octocrab::Result<Vec<T>> {
393 self.get_filtered(page, |_| true).await
394 }
395
396 async fn get_filtered<T: DeserializeOwned + 'static>(
397 &self,
398 page: Page<T>,
399 predicate: impl Fn(&T) -> bool,
400 ) -> octocrab::Result<Vec<T>> {
401 let stream = page.into_stream(&self.client);
402 pin!(stream);
403
404 let mut results = Vec::new();
405
406 while let Some(item) = stream.try_next().await?
407 && predicate(&item)
408 {
409 results.push(item);
410 }
411
412 Ok(results)
413 }
414 }
415
416 #[async_trait::async_trait(?Send)]
417 impl GithubApiClient for OctocrabClient {
418 async fn get_pull_request(
419 &self,
420 repo: &Repository<'_>,
421 pr_number: u64,
422 ) -> Result<PullRequestData> {
423 let pr = self
424 .client
425 .pulls(repo.owner.as_ref(), repo.name.as_ref())
426 .get(pr_number)
427 .await?;
428 Ok(PullRequestData {
429 number: pr.number,
430 user: pr.user.map(|user| GithubUser { login: user.login }),
431 merged_by: pr.merged_by.map(|user| GithubUser { login: user.login }),
432 labels: pr
433 .labels
434 .map(|labels| labels.into_iter().map(|label| label.name).collect()),
435 })
436 }
437
438 async fn get_pull_request_reviews(
439 &self,
440 repo: &Repository<'_>,
441 pr_number: u64,
442 ) -> Result<Vec<PullRequestReview>> {
443 let page = self
444 .client
445 .pulls(repo.owner.as_ref(), repo.name.as_ref())
446 .list_reviews(pr_number)
447 .per_page(PAGE_SIZE)
448 .send()
449 .await?;
450
451 let reviews = self.get_all(page).await?;
452
453 Ok(reviews
454 .into_iter()
455 .map(|review| PullRequestReview {
456 user: review.user.map(|user| GithubUser { login: user.login }),
457 state: review.state.map(|state| match state {
458 OctocrabReviewState::Approved => ReviewState::Approved,
459 _ => ReviewState::Other,
460 }),
461 body: review.body,
462 })
463 .collect())
464 }
465
466 async fn get_pull_request_comments(
467 &self,
468 repo: &Repository<'_>,
469 pr_number: u64,
470 ) -> Result<Vec<PullRequestComment>> {
471 let page = self
472 .client
473 .issues(repo.owner.as_ref(), repo.name.as_ref())
474 .list_comments(pr_number)
475 .per_page(PAGE_SIZE)
476 .send()
477 .await?;
478
479 let comments = self.get_all(page).await?;
480
481 Ok(comments
482 .into_iter()
483 .map(|comment| PullRequestComment {
484 user: GithubUser {
485 login: comment.user.login,
486 },
487 body: comment.body,
488 })
489 .collect())
490 }
491
492 async fn get_commit_authors(
493 &self,
494 repo: &Repository<'_>,
495 commit_shas: &[&CommitSha],
496 ) -> Result<AuthorsForCommits> {
497 let query = graph_ql::build_co_authors_query(
498 repo.owner.as_ref(),
499 repo.name.as_ref(),
500 commit_shas.iter().copied(),
501 );
502 let query = serde_json::json!({ "query": query });
503 self.graphql::<graph_ql::CommitAuthorsResponse>(&query)
504 .await
505 .map(|response| response.repository)
506 }
507
508 async fn check_repo_write_permission(
509 &self,
510 repo: &Repository<'_>,
511 login: &GithubLogin,
512 ) -> Result<bool> {
513 // Check org membership first - we save ourselves a few request that way
514 let page = self
515 .client
516 .orgs(repo.owner.as_ref())
517 .list_members()
518 .per_page(PAGE_SIZE)
519 .send()
520 .await?;
521
522 let members = self.get_all(page).await?;
523
524 if members
525 .into_iter()
526 .any(|member| member.login == login.as_str())
527 {
528 return Ok(true);
529 }
530
531 // TODO: octocrab fails to deserialize the permission response and
532 // does not adhere to the scheme laid out at
533 // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user
534
535 #[derive(serde::Deserialize)]
536 #[serde(rename_all = "lowercase")]
537 enum RepoPermission {
538 Admin,
539 Write,
540 Read,
541 #[serde(other)]
542 Other,
543 }
544
545 #[derive(serde::Deserialize)]
546 struct RepositoryPermissions {
547 permission: RepoPermission,
548 }
549
550 self.client
551 .get::<RepositoryPermissions, _, _>(
552 format!(
553 "/repos/{owner}/{repo}/collaborators/{user}/permission",
554 owner = repo.owner.as_ref(),
555 repo = repo.name.as_ref(),
556 user = login.as_str()
557 ),
558 None::<&()>,
559 )
560 .await
561 .map(|response| {
562 matches!(
563 response.permission,
564 RepoPermission::Write | RepoPermission::Admin
565 )
566 })
567 .map_err(Into::into)
568 }
569
570 async fn add_label_to_issue(
571 &self,
572 repo: &Repository<'_>,
573 label: &str,
574 issue_number: u64,
575 ) -> Result<()> {
576 self.client
577 .issues(repo.owner.as_ref(), repo.name.as_ref())
578 .add_labels(issue_number, &[label.to_owned()])
579 .await
580 .map(|_| ())
581 .map_err(Into::into)
582 }
583 }
584}
585
586#[cfg(feature = "octo-client")]
587pub use octo_client::OctocrabClient;