1use std::{collections::HashMap, fmt, ops::Not, rc::Rc};
2
3use anyhow::Result;
4use derive_more::Deref;
5use serde::Deserialize;
6
7use crate::git::CommitSha;
8
/// Name of the label marking a pull request as needing review.
// NOTE(review): presumably applied via `GitHubApiClient::add_label_to_issue`; confirm at call sites.
pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
10
/// A GitHub account, identified only by its login name.
#[derive(Debug, Clone)]
pub struct GitHubUser {
    /// The account's login.
    pub login: String,
}
15
/// The subset of pull-request metadata returned by
/// [`GitHubApiClient::get_pull_request`].
#[derive(Debug, Clone)]
pub struct PullRequestData {
    /// The pull request number.
    pub number: u64,
    /// The pull request's `user` (its author), if the API returned one.
    pub user: Option<GitHubUser>,
    /// The user who merged the pull request, if the API returned one.
    pub merged_by: Option<GitHubUser>,
    /// Names of the labels attached to the pull request, if the API returned them.
    pub labels: Option<Vec<String>>,
}
23
/// State of a pull-request review, reduced to "approved" versus everything else.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReviewState {
    /// The review approved the pull request.
    Approved,
    /// Any state other than an approval.
    Other,
}
29
/// A single review submitted on a pull request.
#[derive(Debug, Clone)]
pub struct PullRequestReview {
    /// The reviewer, if the API returned one.
    pub user: Option<GitHubUser>,
    /// The review's state, if the API returned one.
    pub state: Option<ReviewState>,
    /// The review's top-level text, if any.
    pub body: Option<String>,
}
36
37impl PullRequestReview {
38 pub fn with_body(self, body: impl ToString) -> Self {
39 Self {
40 body: Some(body.to_string()),
41 ..self
42 }
43 }
44}
45
/// A comment posted on a pull request.
#[derive(Debug, Clone)]
pub struct PullRequestComment {
    /// The comment's author.
    pub user: GitHubUser,
    /// The comment text, if the API returned one.
    pub body: Option<String>,
}
51
/// A GitHub login name.
///
/// Derefs to the wrapped `String`; its `Display` impl renders it as `@login`.
/// Deserializes from an object with a `login` field, such as the GraphQL
/// `user { login }` selection.
#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
pub struct GithubLogin {
    login: String,
}
56
57impl GithubLogin {
58 pub fn new(login: String) -> Self {
59 Self { login }
60 }
61}
62
63impl fmt::Display for GithubLogin {
64 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
65 write!(formatter, "@{}", self.login)
66 }
67}
68
/// A git commit author as returned by the GraphQL commit-authors query.
#[derive(Debug, Deserialize, Clone)]
pub struct CommitAuthor {
    /// The git author name.
    name: String,
    /// The git author email.
    email: String,
    /// The linked GitHub account; `None` when the API returns no `user`.
    user: Option<GithubLogin>,
}
75
76impl CommitAuthor {
77 pub(crate) fn user(&self) -> Option<&GithubLogin> {
78 self.user.as_ref()
79 }
80}
81
82impl PartialEq for CommitAuthor {
83 fn eq(&self, other: &Self) -> bool {
84 self.user.as_ref().zip(other.user.as_ref()).map_or_else(
85 || self.email == other.email || self.name == other.name,
86 |(l, r)| l == r,
87 )
88 }
89}
90
91impl fmt::Display for CommitAuthor {
92 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
93 match self.user.as_ref() {
94 Some(user) => write!(formatter, "{} ({user})", self.name),
95 None => write!(formatter, "{} ({})", self.name, self.email),
96 }
97 }
98}
99
/// Authors of a single commit, deserialized from the `... on Commit` selection
/// built by `graph_ql::build_co_authors_query`.
#[derive(Debug, Deserialize)]
pub struct CommitAuthors {
    /// The commit's git author (GraphQL `author`).
    #[serde(rename = "author")]
    primary_author: CommitAuthor,
    /// The `nodes` of the GraphQL `authors` connection. This list may also contain
    /// the primary author, which the `co_authors()` accessor filters out.
    #[serde(rename = "authors", deserialize_with = "graph_ql::deserialize_nodes")]
    co_authors: Vec<CommitAuthor>,
}
107
108impl CommitAuthors {
109 pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
110 self.co_authors.is_empty().not().then(|| {
111 self.co_authors
112 .iter()
113 .filter(|co_author| *co_author != &self.primary_author)
114 })
115 }
116}
117
/// Commit authors keyed by commit SHA, as returned by
/// [`GitHubApiClient::get_commit_authors`].
#[derive(Debug, Deref)]
pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);

impl AuthorsForCommits {
    /// Prefix placed before each SHA to form the GraphQL alias of that commit's
    /// `object` field. It is stripped again during deserialization.
    // NOTE(review): presumably needed because GraphQL aliases may not start with a digit.
    const SHA_PREFIX: &'static str = "commit";
}
124
125impl<'de> serde::Deserialize<'de> for AuthorsForCommits {
126 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
127 where
128 D: serde::Deserializer<'de>,
129 {
130 let raw = HashMap::<String, CommitAuthors>::deserialize(deserializer)?;
131 let map = raw
132 .into_iter()
133 .map(|(key, value)| {
134 let sha = key
135 .strip_prefix(AuthorsForCommits::SHA_PREFIX)
136 .unwrap_or(&key);
137 (CommitSha::new(sha.to_owned()), value)
138 })
139 .collect();
140 Ok(Self(map))
141 }
142}
143
144#[async_trait::async_trait(?Send)]
145pub trait GitHubApiClient {
146 async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData>;
147 async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>>;
148 async fn get_pull_request_comments(&self, pr_number: u64) -> Result<Vec<PullRequestComment>>;
149 async fn get_commit_authors(&self, commit_shas: &[&CommitSha]) -> Result<AuthorsForCommits>;
150 async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool>;
151 async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool>;
152 async fn actor_has_repository_write_permission(
153 &self,
154 login: &GithubLogin,
155 ) -> anyhow::Result<bool> {
156 Ok(self.check_org_membership(login).await?
157 || self.check_repo_write_permission(login).await?)
158 }
159 async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()>;
160}
161
/// Shared handle to a [`GitHubApiClient`] implementation.
///
/// Derefs to the inner `Rc<dyn GitHubApiClient>`, so trait methods can be called
/// directly on a `GitHubClient`.
#[derive(Deref)]
pub struct GitHubClient {
    api: Rc<dyn GitHubApiClient>,
}
166
167impl GitHubClient {
168 pub fn new(api: Rc<dyn GitHubApiClient>) -> Self {
169 Self { api }
170 }
171
172 #[cfg(feature = "octo-client")]
173 pub async fn for_app(app_id: u64, app_private_key: &str) -> Result<Self> {
174 let client = OctocrabClient::new(app_id, app_private_key).await?;
175 Ok(Self::new(Rc::new(client)))
176 }
177}
178
179pub mod graph_ql {
180 use anyhow::{Context as _, Result};
181 use itertools::Itertools as _;
182 use serde::Deserialize;
183
184 use crate::git::CommitSha;
185
186 use super::AuthorsForCommits;
187
188 #[derive(Debug, Deserialize)]
189 pub struct GraphQLResponse<T> {
190 pub data: Option<T>,
191 pub errors: Option<Vec<GraphQLError>>,
192 }
193
194 impl<T> GraphQLResponse<T> {
195 pub fn into_data(self) -> Result<T> {
196 if let Some(errors) = &self.errors {
197 if !errors.is_empty() {
198 let messages: String = errors.iter().map(|e| e.message.as_str()).join("; ");
199 anyhow::bail!("GraphQL error: {messages}");
200 }
201 }
202
203 self.data.context("GraphQL response contained no data")
204 }
205 }
206
207 #[derive(Debug, Deserialize)]
208 pub struct GraphQLError {
209 pub message: String,
210 }
211
212 #[derive(Debug, Deserialize)]
213 pub struct CommitAuthorsResponse {
214 pub repository: AuthorsForCommits,
215 }
216
217 pub fn deserialize_nodes<'de, T, D>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
218 where
219 T: Deserialize<'de>,
220 D: serde::Deserializer<'de>,
221 {
222 #[derive(Deserialize)]
223 struct Nodes<T> {
224 nodes: Vec<T>,
225 }
226 Nodes::<T>::deserialize(deserializer).map(|wrapper| wrapper.nodes)
227 }
228
229 pub fn build_co_authors_query<'a>(
230 org: &str,
231 repo: &str,
232 shas: impl IntoIterator<Item = &'a CommitSha>,
233 ) -> String {
234 const FRAGMENT: &str = r#"
235 ... on Commit {
236 author {
237 name
238 email
239 user { login }
240 }
241 authors(first: 10) {
242 nodes {
243 name
244 email
245 user { login }
246 }
247 }
248 }
249 "#;
250
251 let objects = shas
252 .into_iter()
253 .map(|commit_sha| {
254 format!(
255 "{sha_prefix}{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
256 sha_prefix = AuthorsForCommits::SHA_PREFIX,
257 sha = **commit_sha,
258 )
259 })
260 .join("\n");
261
262 format!("{{ repository(owner: \"{org}\", name: \"{repo}\") {{ {objects} }} }}")
263 .replace("\n", "")
264 }
265}
266
#[cfg(feature = "octo-client")]
mod octo_client {
    //! `GitHubApiClient` implementation backed by `octocrab`. It authenticates as the
    //! GitHub App installation on `ORG`, and all requests target `ORG/REPO`.

    use anyhow::{Context, Result};
    use futures::TryStreamExt as _;
    use jsonwebtoken::EncodingKey;
    use octocrab::{
        Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
        service::middleware::cache::mem::InMemoryCache,
    };
    use serde::de::DeserializeOwned;
    use tokio::pin;

    use crate::{git::CommitSha, github::graph_ql};

    use super::{
        AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment,
        PullRequestData, PullRequestReview, ReviewState,
    };

    /// Page size requested from paginated REST endpoints.
    const PAGE_SIZE: u8 = 100;
    /// Organization owning the repository; also selects the app installation.
    const ORG: &str = "zed-industries";
    /// Repository all requests are scoped to.
    const REPO: &str = "zed";

    /// GitHub API client authenticated as a GitHub App installation.
    pub struct OctocrabClient {
        client: Octocrab,
    }

    impl OctocrabClient {
        /// Authenticates as GitHub App `app_id` using its RSA private key (PEM). Returns
        /// a client for the app's installation on `ORG`, with an in-memory response
        /// cache.
        ///
        /// # Errors
        /// Fails if:
        /// - the key is not valid RSA PEM,
        /// - the client cannot be built,
        /// - the installations cannot be listed, or
        /// - no installation belongs to `ORG`.
        pub async fn new(app_id: u64, app_private_key: &str) -> Result<Self> {
            let octocrab = Octocrab::builder()
                .cache(InMemoryCache::new())
                .app(
                    app_id.into(),
                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
                )
                .build()?;

            // NOTE(review): only the first page of installations is inspected
            // (`send` + `take_items`); fine while the app has few installations.
            let installations = octocrab
                .apps()
                .installations()
                .send()
                .await
                .context("Failed to fetch installations")?
                .take_items();

            let installation_id = installations
                .into_iter()
                .find(|installation| installation.account.login == ORG)
                .context("Could not find Zed repository in installations")?
                .id;

            let client = octocrab.installation(installation_id)?;
            Ok(Self { client })
        }

        /// Sends a GraphQL request and unwraps the response envelope into `R`.
        ///
        /// # Errors
        /// Fails on transport errors, when the body is not a GraphQL envelope of `R`,
        /// or when the envelope reports errors or carries no data.
        async fn graphql<R: DeserializeOwned>(&self, query: &serde_json::Value) -> Result<R> {
            let response: serde_json::Value = self.client.graphql(query).await?;
            let parsed: graph_ql::GraphQLResponse<R> = serde_json::from_value(response)
                .context("Failed to parse GraphQL response envelope")?;
            parsed.into_data()
        }

        /// Collects every item of `page` and of all pages that follow it.
        async fn get_all<T: DeserializeOwned + 'static>(
            &self,
            page: Page<T>,
        ) -> octocrab::Result<Vec<T>> {
            self.get_filtered(page, |_| true).await
        }

        /// Walks `page` and all following pages, keeping the items for which
        /// `predicate` returns `true`.
        ///
        /// A rejected item is skipped; it does not end the walk.
        async fn get_filtered<T: DeserializeOwned + 'static>(
            &self,
            page: Page<T>,
            predicate: impl Fn(&T) -> bool,
        ) -> octocrab::Result<Vec<T>> {
            let stream = page.into_stream(&self.client);
            pin!(stream);

            let mut results = Vec::new();

            // Filter in the loop body, not in the `while let` condition: a
            // `&& predicate(&item)` condition acts as take-while and silently drops
            // every item after the first rejected one.
            while let Some(item) = stream.try_next().await? {
                if predicate(&item) {
                    results.push(item);
                }
            }

            Ok(results)
        }
    }

    #[async_trait::async_trait(?Send)]
    impl GitHubApiClient for OctocrabClient {
        /// Fetches the pull request via the REST pulls endpoint and keeps only
        /// number, author, merger and label names.
        async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
            let pr = self.client.pulls(ORG, REPO).get(pr_number).await?;
            Ok(PullRequestData {
                number: pr.number,
                user: pr.user.map(|user| GitHubUser { login: user.login }),
                merged_by: pr.merged_by.map(|user| GitHubUser { login: user.login }),
                labels: pr
                    .labels
                    .map(|labels| labels.into_iter().map(|label| label.name).collect()),
            })
        }

        /// Fetches every page of reviews. Octocrab's review state is collapsed to
        /// `Approved` or `Other`.
        async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
            let page = self
                .client
                .pulls(ORG, REPO)
                .list_reviews(pr_number)
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let reviews = self.get_all(page).await?;

            Ok(reviews
                .into_iter()
                .map(|review| PullRequestReview {
                    user: review.user.map(|user| GitHubUser { login: user.login }),
                    state: review.state.map(|state| match state {
                        OctocrabReviewState::Approved => ReviewState::Approved,
                        _ => ReviewState::Other,
                    }),
                    body: review.body,
                })
                .collect())
        }

        /// Fetches every page of comments from the issues comments endpoint for
        /// `pr_number`.
        async fn get_pull_request_comments(
            &self,
            pr_number: u64,
        ) -> Result<Vec<PullRequestComment>> {
            let page = self
                .client
                .issues(ORG, REPO)
                .list_comments(pr_number)
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let comments = self.get_all(page).await?;

            Ok(comments
                .into_iter()
                .map(|comment| PullRequestComment {
                    user: GitHubUser {
                        login: comment.user.login,
                    },
                    body: comment.body,
                })
                .collect())
        }

        /// Fetches the authors of all `commit_shas` in a single GraphQL request.
        async fn get_commit_authors(
            &self,
            commit_shas: &[&CommitSha],
        ) -> Result<AuthorsForCommits> {
            let query = graph_ql::build_co_authors_query(ORG, REPO, commit_shas.iter().copied());
            let query = serde_json::json!({ "query": query });
            self.graphql::<graph_ql::CommitAuthorsResponse>(&query)
                .await
                .map(|response| response.repository)
        }

        /// Lists every member of `ORG` (all pages) and checks whether `login` is among
        /// them. The comparison is an exact, case-sensitive string match.
        async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
            let page = self
                .client
                .orgs(ORG)
                .list_members()
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let members = self.get_all(page).await?;

            Ok(members
                .into_iter()
                .any(|member| member.login == login.as_str()))
        }

        /// Queries the collaborator-permission endpoint directly. Returns `true` when
        /// the reported permission is `write` or `admin`.
        async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool> {
            // TODO: octocrab fails to deserialize the permission response and
            // does not adhere to the schema laid out at
            // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user

            #[derive(serde::Deserialize)]
            #[serde(rename_all = "lowercase")]
            enum RepoPermission {
                Admin,
                Write,
                Read,
                #[serde(other)]
                Other,
            }

            #[derive(serde::Deserialize)]
            struct RepositoryPermissions {
                permission: RepoPermission,
            }

            self.client
                .get::<RepositoryPermissions, _, _>(
                    format!(
                        "/repos/{ORG}/{REPO}/collaborators/{user}/permission",
                        user = login.as_str()
                    ),
                    None::<&()>,
                )
                .await
                .map(|response| {
                    matches!(
                        response.permission,
                        RepoPermission::Write | RepoPermission::Admin
                    )
                })
                .map_err(Into::into)
        }

        /// Adds `label` to issue (or pull request) `issue_number`.
        async fn add_label_to_issue(&self, label: &str, issue_number: u64) -> Result<()> {
            self.client
                .issues(ORG, REPO)
                .add_labels(issue_number, &[label.to_owned()])
                .await
                .map(|_| ())
                .map_err(Into::into)
        }
    }
}
494
495#[cfg(feature = "octo-client")]
496pub use octo_client::OctocrabClient;