1use std::{borrow::Cow, collections::HashMap, fmt, ops::Not, sync::LazyLock};
2
3use anyhow::Result;
4use derive_more::Deref;
5use serde::Deserialize;
6
7use crate::git::CommitSha;
8
/// Name of the GitHub label that marks a pull request as needing review.
pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
10
/// A GitHub account as reported by the REST API, reduced to its login.
#[derive(Clone, Debug)]
pub struct GithubUser {
    /// The account's login name.
    pub login: String,
}
15
16#[derive(Debug, Clone)]
17pub struct PullRequestData {
18 pub number: u64,
19 pub user: Option<GithubUser>,
20 pub merged_by: Option<GithubUser>,
21 pub labels: Option<Vec<String>>,
22}
23
/// Coarse classification of a pull request review: approved, or anything else.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReviewState {
    /// The reviewer approved the pull request.
    Approved,
    /// Every review state other than an approval.
    Other,
}
29
30#[derive(Debug, Clone)]
31pub struct PullRequestReview {
32 pub user: Option<GithubUser>,
33 pub state: Option<ReviewState>,
34 pub body: Option<String>,
35}
36
37impl PullRequestReview {
38 pub fn with_body(self, body: impl ToString) -> Self {
39 Self {
40 body: Some(body.to_string()),
41 ..self
42 }
43 }
44}
45
46#[derive(Debug, Clone)]
47pub struct PullRequestComment {
48 pub user: GithubUser,
49 pub body: Option<String>,
50}
51
52#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
53pub struct GithubLogin {
54 login: String,
55}
56
57impl GithubLogin {
58 pub fn new(login: String) -> Self {
59 Self { login }
60 }
61}
62
63impl fmt::Display for GithubLogin {
64 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
65 write!(formatter, "@{}", self.login)
66 }
67}
68
69#[derive(Debug, Deserialize, Clone)]
70pub struct CommitAuthor {
71 name: String,
72 email: String,
73 user: Option<GithubLogin>,
74}
75
76pub(crate) static ZED_ZIPPY_AUTHOR: LazyLock<CommitAuthor> = LazyLock::new(|| CommitAuthor {
77 name: "Zed Zippy".to_string(),
78 email: "234243425+zed-zippy[bot]@users.noreply.github.com".to_string(),
79 user: Some(GithubLogin {
80 login: "zed-zippy[bot]".to_string(),
81 }),
82});
83
84impl CommitAuthor {
85 pub(crate) fn user(&self) -> Option<&GithubLogin> {
86 self.user.as_ref()
87 }
88}
89
90impl PartialEq for CommitAuthor {
91 fn eq(&self, other: &Self) -> bool {
92 self.user.as_ref().zip(other.user.as_ref()).map_or_else(
93 || self.email == other.email || self.name == other.name,
94 |(l, r)| l == r,
95 )
96 }
97}
98
99impl fmt::Display for CommitAuthor {
100 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
101 match self.user.as_ref() {
102 Some(user) => write!(formatter, "{} ({user})", self.name),
103 None => write!(formatter, "{} ({})", self.name, self.email),
104 }
105 }
106}
107
108#[derive(Debug, Deserialize)]
109pub struct CommitAuthors {
110 #[serde(rename = "author")]
111 primary_author: CommitAuthor,
112 #[serde(rename = "authors", deserialize_with = "graph_ql::deserialize_nodes")]
113 co_authors: Vec<CommitAuthor>,
114}
115
116impl CommitAuthors {
117 pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
118 self.co_authors.is_empty().not().then(|| {
119 self.co_authors
120 .iter()
121 .filter(|co_author| *co_author != &self.primary_author)
122 })
123 }
124}
125
126#[derive(Debug, Deref)]
127pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
128
129impl AuthorsForCommits {
130 const SHA_PREFIX: &'static str = "commit";
131}
132
133impl<'de> serde::Deserialize<'de> for AuthorsForCommits {
134 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
135 where
136 D: serde::Deserializer<'de>,
137 {
138 let raw = HashMap::<String, CommitAuthors>::deserialize(deserializer)?;
139 let map = raw
140 .into_iter()
141 .map(|(key, value)| {
142 let sha = key
143 .strip_prefix(AuthorsForCommits::SHA_PREFIX)
144 .unwrap_or(&key);
145 (CommitSha::new(sha.to_owned()), value)
146 })
147 .collect();
148 Ok(Self(map))
149 }
150}
151
/// A GitHub repository, identified by its owner login and repository name.
///
/// The strings are held as `Cow`s so a repository can be built from borrowed
/// data, including in `const` context via [`Repository::new_static`].
#[derive(Clone)]
pub struct Repository<'a> {
    owner: Cow<'a, str>,
    name: Cow<'a, str>,
}

impl<'a> Repository<'a> {
    /// The `zed-industries/zed` repository.
    pub const ZED: Repository<'static> = Repository::new_static("zed-industries", "zed");

    /// Creates a repository handle that borrows `owner` and `name`.
    pub fn new(owner: &'a str, name: &'a str) -> Self {
        Repository {
            owner: owner.into(),
            name: name.into(),
        }
    }

    /// The owner's login.
    pub fn owner(&self) -> &str {
        self.owner.as_ref()
    }

    /// The repository's name.
    pub fn name(&self) -> &str {
        self.name.as_ref()
    }
}

impl Repository<'static> {
    /// `const` constructor for repositories whose names are `'static` literals.
    pub const fn new_static(owner: &'static str, name: &'static str) -> Self {
        Repository {
            owner: Cow::Borrowed(owner),
            name: Cow::Borrowed(name),
        }
    }
}
185
186#[async_trait::async_trait(?Send)]
187pub trait GithubApiClient {
188 async fn get_pull_request(
189 &self,
190 repo: &Repository<'_>,
191 pr_number: u64,
192 ) -> Result<PullRequestData>;
193 async fn get_pull_request_reviews(
194 &self,
195 repo: &Repository<'_>,
196 pr_number: u64,
197 ) -> Result<Vec<PullRequestReview>>;
198 async fn get_pull_request_comments(
199 &self,
200 repo: &Repository<'_>,
201 pr_number: u64,
202 ) -> Result<Vec<PullRequestComment>>;
203 async fn get_commit_authors(
204 &self,
205 repo: &Repository<'_>,
206 commit_shas: &[&CommitSha],
207 ) -> Result<AuthorsForCommits>;
208 async fn check_repo_write_permission(
209 &self,
210 repo: &Repository<'_>,
211 login: &GithubLogin,
212 ) -> Result<bool>;
213 async fn add_label_to_issue(
214 &self,
215 repo: &Repository<'_>,
216 label: &str,
217 issue_number: u64,
218 ) -> Result<()>;
219}
220
221pub mod graph_ql {
222 use anyhow::{Context as _, Result};
223 use itertools::Itertools as _;
224 use serde::Deserialize;
225
226 use crate::git::CommitSha;
227
228 use super::AuthorsForCommits;
229
230 #[derive(Debug, Deserialize)]
231 pub struct GraphQLResponse<T> {
232 pub data: Option<T>,
233 pub errors: Option<Vec<GraphQLError>>,
234 }
235
236 impl<T> GraphQLResponse<T> {
237 pub fn into_data(self) -> Result<T> {
238 if let Some(errors) = &self.errors {
239 if !errors.is_empty() {
240 let messages: String = errors.iter().map(|e| e.message.as_str()).join("; ");
241 anyhow::bail!("GraphQL error: {messages}");
242 }
243 }
244
245 self.data.context("GraphQL response contained no data")
246 }
247 }
248
249 #[derive(Debug, Deserialize)]
250 pub struct GraphQLError {
251 pub message: String,
252 }
253
254 #[derive(Debug, Deserialize)]
255 pub struct CommitAuthorsResponse {
256 pub repository: AuthorsForCommits,
257 }
258
259 pub fn deserialize_nodes<'de, T, D>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
260 where
261 T: Deserialize<'de>,
262 D: serde::Deserializer<'de>,
263 {
264 #[derive(Deserialize)]
265 struct Nodes<T> {
266 nodes: Vec<T>,
267 }
268 Nodes::<T>::deserialize(deserializer).map(|wrapper| wrapper.nodes)
269 }
270
271 pub fn build_co_authors_query<'a>(
272 org: &str,
273 repo: &str,
274 shas: impl IntoIterator<Item = &'a CommitSha>,
275 ) -> String {
276 const FRAGMENT: &str = r#"
277 ... on Commit {
278 author {
279 name
280 email
281 user { login }
282 }
283 authors(first: 10) {
284 nodes {
285 name
286 email
287 user { login }
288 }
289 }
290 }
291 "#;
292
293 let objects = shas
294 .into_iter()
295 .map(|commit_sha| {
296 format!(
297 "{sha_prefix}{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
298 sha_prefix = AuthorsForCommits::SHA_PREFIX,
299 sha = **commit_sha,
300 )
301 })
302 .join("\n");
303
304 format!("{{ repository(owner: \"{org}\", name: \"{repo}\") {{ {objects} }} }}")
305 .replace("\n", "")
306 }
307}
308
/// `octocrab`-backed implementation of the GitHub client trait.
#[cfg(feature = "octo-client")]
mod octo_client {
    use std::collections::HashMap;

    use anyhow::{Context, Result};
    use futures::TryStreamExt as _;
    use jsonwebtoken::EncodingKey;
    use octocrab::{
        Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
        service::middleware::cache::mem::InMemoryCache,
    };
    use serde::de::DeserializeOwned;
    use tokio::pin;

    use crate::{
        git::CommitSha,
        github::{Repository, graph_ql},
    };

    use super::{
        AuthorsForCommits, GithubApiClient, GithubLogin, GithubUser, PullRequestComment,
        PullRequestData, PullRequestReview, ReviewState,
    };

    /// Number of items requested per page from paginated REST endpoints.
    const PAGE_SIZE: u8 = 100;

    /// A [`GithubApiClient`] that authenticates as a GitHub App installation.
    pub struct OctocrabClient {
        client: Octocrab,
    }

    impl OctocrabClient {
        /// Authenticates as GitHub App `app_id` with `app_private_key` (RSA PEM)
        /// and scopes the client to the installation whose account login equals `org`.
        ///
        /// # Errors
        /// Fails if the key is not valid RSA PEM, the client cannot be built,
        /// installations cannot be fetched, or no installation belongs to `org`.
        pub async fn new(app_id: u64, app_private_key: &str, org: &str) -> Result<Self> {
            let octocrab = Octocrab::builder()
                .cache(InMemoryCache::new())
                .app(
                    app_id.into(),
                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
                )
                .build()?;

            // NOTE(review): only the first page of installations is searched.
            let installations = octocrab
                .apps()
                .installations()
                .send()
                .await
                .context("Failed to fetch installations")?
                .take_items();

            let installation_id = installations
                .into_iter()
                .find(|installation| installation.account.login == org)
                .with_context(|| {
                    format!("Could not find a GitHub App installation for `{org}`")
                })?
                .id;

            let client = octocrab.installation(installation_id)?;
            Ok(Self { client })
        }

        /// Sends `query` (a full JSON request body, e.g. `{"query": …}`) to the
        /// GraphQL endpoint and unwraps the response envelope into `R`.
        async fn graphql<R: DeserializeOwned>(&self, query: &serde_json::Value) -> Result<R> {
            let response: serde_json::Value = self.client.graphql(query).await?;
            let parsed: graph_ql::GraphQLResponse<R> = serde_json::from_value(response)
                .context("Failed to parse GraphQL response envelope")?;
            parsed.into_data()
        }

        /// Collects every item from `page` and all pages that follow it.
        async fn get_all<T: DeserializeOwned + 'static>(
            &self,
            page: Page<T>,
        ) -> octocrab::Result<Vec<T>> {
            self.get_filtered(page, |_| true).await
        }

        /// Walks `page` and all following pages, keeping the items for which
        /// `predicate` returns `true`. Non-matching items are skipped; they do
        /// not end the traversal.
        async fn get_filtered<T: DeserializeOwned + 'static>(
            &self,
            page: Page<T>,
            predicate: impl Fn(&T) -> bool,
        ) -> octocrab::Result<Vec<T>> {
            let stream = page.into_stream(&self.client);
            pin!(stream);

            let mut results = Vec::new();

            while let Some(item) = stream.try_next().await? {
                if predicate(&item) {
                    results.push(item);
                }
            }

            Ok(results)
        }
    }

    #[async_trait::async_trait(?Send)]
    impl GithubApiClient for OctocrabClient {
        async fn get_pull_request(
            &self,
            repo: &Repository<'_>,
            pr_number: u64,
        ) -> Result<PullRequestData> {
            let pr = self
                .client
                .pulls(repo.owner.as_ref(), repo.name.as_ref())
                .get(pr_number)
                .await?;
            Ok(PullRequestData {
                number: pr.number,
                user: pr.user.map(|user| GithubUser { login: user.login }),
                merged_by: pr.merged_by.map(|user| GithubUser { login: user.login }),
                labels: pr
                    .labels
                    .map(|labels| labels.into_iter().map(|label| label.name).collect()),
            })
        }

        async fn get_pull_request_reviews(
            &self,
            repo: &Repository<'_>,
            pr_number: u64,
        ) -> Result<Vec<PullRequestReview>> {
            let page = self
                .client
                .pulls(repo.owner.as_ref(), repo.name.as_ref())
                .list_reviews(pr_number)
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let reviews = self.get_all(page).await?;

            Ok(reviews
                .into_iter()
                .map(|review| PullRequestReview {
                    user: review.user.map(|user| GithubUser { login: user.login }),
                    state: review.state.map(|state| match state {
                        OctocrabReviewState::Approved => ReviewState::Approved,
                        _ => ReviewState::Other,
                    }),
                    body: review.body,
                })
                .collect())
        }

        async fn get_pull_request_comments(
            &self,
            repo: &Repository<'_>,
            pr_number: u64,
        ) -> Result<Vec<PullRequestComment>> {
            // Pull requests share the issue-comment endpoint.
            let page = self
                .client
                .issues(repo.owner.as_ref(), repo.name.as_ref())
                .list_comments(pr_number)
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let comments = self.get_all(page).await?;

            Ok(comments
                .into_iter()
                .map(|comment| PullRequestComment {
                    user: GithubUser {
                        login: comment.user.login,
                    },
                    body: comment.body,
                })
                .collect())
        }

        async fn get_commit_authors(
            &self,
            repo: &Repository<'_>,
            commit_shas: &[&CommitSha],
        ) -> Result<AuthorsForCommits> {
            // With no commits the generated query would contain an empty selection
            // set, which is invalid GraphQL. There is nothing to look up anyway.
            if commit_shas.is_empty() {
                return Ok(AuthorsForCommits(HashMap::new()));
            }

            let query = graph_ql::build_co_authors_query(
                repo.owner.as_ref(),
                repo.name.as_ref(),
                commit_shas.iter().copied(),
            );
            let query = serde_json::json!({ "query": query });
            self.graphql::<graph_ql::CommitAuthorsResponse>(&query)
                .await
                .map(|response| response.repository)
        }

        async fn check_repo_write_permission(
            &self,
            repo: &Repository<'_>,
            login: &GithubLogin,
        ) -> Result<bool> {
            // Check org membership first - we save ourselves a few requests that way
            let page = self
                .client
                .orgs(repo.owner.as_ref())
                .list_members()
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let members = self.get_all(page).await?;

            if members
                .into_iter()
                .any(|member| member.login == login.as_str())
            {
                return Ok(true);
            }

            // TODO: octocrab fails to deserialize the permission response and
            // does not adhere to the schema laid out at
            // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user

            #[derive(serde::Deserialize)]
            #[serde(rename_all = "lowercase")]
            enum RepoPermission {
                Admin,
                Write,
                Read,
                #[serde(other)]
                Other,
            }

            #[derive(serde::Deserialize)]
            struct RepositoryPermissions {
                permission: RepoPermission,
            }

            self.client
                .get::<RepositoryPermissions, _, _>(
                    format!(
                        "/repos/{owner}/{repo}/collaborators/{user}/permission",
                        owner = repo.owner.as_ref(),
                        repo = repo.name.as_ref(),
                        user = login.as_str()
                    ),
                    None::<&()>,
                )
                .await
                .map(|response| {
                    matches!(
                        response.permission,
                        RepoPermission::Write | RepoPermission::Admin
                    )
                })
                .map_err(Into::into)
        }

        async fn add_label_to_issue(
            &self,
            repo: &Repository<'_>,
            label: &str,
            issue_number: u64,
        ) -> Result<()> {
            self.client
                .issues(repo.owner.as_ref(), repo.name.as_ref())
                .add_labels(issue_number, &[label.to_owned()])
                .await
                .map(|_| ())
                .map_err(Into::into)
        }
    }
}
568
/// Re-export of the `octocrab`-backed client (only with the `octo-client` feature).
#[cfg(feature = "octo-client")]
pub use self::octo_client::OctocrabClient;