1use std::{collections::HashMap, fmt, ops::Not, rc::Rc};
2
3use anyhow::Result;
4use derive_more::Deref;
5use serde::Deserialize;
6
7use crate::git::CommitSha;
8
/// Label applied to pull requests that are waiting for a review.
pub const PR_REVIEW_LABEL: &str = "PR state:needs review";

/// A GitHub account, identified only by its login.
#[derive(Clone, Debug)]
pub struct GitHubUser {
    /// The account's login name.
    pub login: String,
}
15
16#[derive(Debug, Clone)]
17pub struct PullRequestData {
18 pub number: u64,
19 pub user: Option<GitHubUser>,
20 pub merged_by: Option<GitHubUser>,
21}
22
/// Coarse classification of a pull request review: approvals versus everything else.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReviewState {
    /// The reviewer approved the pull request.
    Approved,
    /// Any review state other than an approval.
    Other,
}
28
29#[derive(Debug, Clone)]
30pub struct PullRequestReview {
31 pub user: Option<GitHubUser>,
32 pub state: Option<ReviewState>,
33}
34
35#[derive(Debug, Clone)]
36pub struct PullRequestComment {
37 pub user: GitHubUser,
38 pub body: Option<String>,
39}
40
41#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
42pub struct GithubLogin {
43 login: String,
44}
45
46impl GithubLogin {
47 pub fn new(login: String) -> Self {
48 Self { login }
49 }
50}
51
52impl fmt::Display for GithubLogin {
53 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
54 write!(formatter, "@{}", self.login)
55 }
56}
57
58#[derive(Debug, Deserialize, Clone)]
59pub struct CommitAuthor {
60 name: String,
61 email: String,
62 user: Option<GithubLogin>,
63}
64
65impl CommitAuthor {
66 pub(crate) fn user(&self) -> Option<&GithubLogin> {
67 self.user.as_ref()
68 }
69}
70
71impl PartialEq for CommitAuthor {
72 fn eq(&self, other: &Self) -> bool {
73 self.user.as_ref().zip(other.user.as_ref()).map_or_else(
74 || self.email == other.email || self.name == other.name,
75 |(l, r)| l == r,
76 )
77 }
78}
79
80impl fmt::Display for CommitAuthor {
81 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
82 match self.user.as_ref() {
83 Some(user) => write!(formatter, "{} ({user})", self.name),
84 None => write!(formatter, "{} ({})", self.name, self.email),
85 }
86 }
87}
88
89#[derive(Debug, Deserialize)]
90pub struct CommitAuthors {
91 #[serde(rename = "author")]
92 primary_author: CommitAuthor,
93 #[serde(rename = "authors")]
94 co_authors: Vec<CommitAuthor>,
95}
96
97impl CommitAuthors {
98 pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
99 self.co_authors.is_empty().not().then(|| {
100 self.co_authors
101 .iter()
102 .filter(|co_author| *co_author != &self.primary_author)
103 })
104 }
105}
106
107#[derive(Debug, Deserialize, Deref)]
108pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
109
110#[async_trait::async_trait(?Send)]
111pub trait GitHubApiClient {
112 async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData>;
113 async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>>;
114 async fn get_pull_request_comments(&self, pr_number: u64) -> Result<Vec<PullRequestComment>>;
115 async fn get_commit_authors(&self, commit_shas: &[&CommitSha]) -> Result<AuthorsForCommits>;
116 async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool>;
117 async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool>;
118 async fn actor_has_repository_write_permission(
119 &self,
120 login: &GithubLogin,
121 ) -> anyhow::Result<bool> {
122 Ok(self.check_org_membership(login).await?
123 || self.check_repo_write_permission(login).await?)
124 }
125 async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()>;
126}
127
128#[derive(Deref)]
129pub struct GitHubClient {
130 api: Rc<dyn GitHubApiClient>,
131}
132
133impl GitHubClient {
134 pub fn new(api: Rc<dyn GitHubApiClient>) -> Self {
135 Self { api }
136 }
137
138 #[cfg(feature = "octo-client")]
139 pub async fn for_app(app_id: u64, app_private_key: &str) -> Result<Self> {
140 let client = OctocrabClient::new(app_id, app_private_key).await?;
141 Ok(Self::new(Rc::new(client)))
142 }
143}
144
145#[cfg(feature = "octo-client")]
146mod octo_client {
147 use anyhow::{Context, Result};
148 use futures::TryStreamExt as _;
149 use itertools::Itertools;
150 use jsonwebtoken::EncodingKey;
151 use octocrab::{
152 Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
153 service::middleware::cache::mem::InMemoryCache,
154 };
155 use serde::de::DeserializeOwned;
156 use tokio::pin;
157
158 use crate::git::CommitSha;
159
160 use super::{
161 AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment,
162 PullRequestData, PullRequestReview, ReviewState,
163 };
164
165 const PAGE_SIZE: u8 = 100;
166 const ORG: &str = "zed-industries";
167 const REPO: &str = "zed";
168
169 pub struct OctocrabClient {
170 client: Octocrab,
171 }
172
173 impl OctocrabClient {
174 pub async fn new(app_id: u64, app_private_key: &str) -> Result<Self> {
175 let octocrab = Octocrab::builder()
176 .cache(InMemoryCache::new())
177 .app(
178 app_id.into(),
179 EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
180 )
181 .build()?;
182
183 let installations = octocrab
184 .apps()
185 .installations()
186 .send()
187 .await
188 .context("Failed to fetch installations")?
189 .take_items();
190
191 let installation_id = installations
192 .into_iter()
193 .find(|installation| installation.account.login == ORG)
194 .context("Could not find Zed repository in installations")?
195 .id;
196
197 let client = octocrab.installation(installation_id)?;
198 Ok(Self { client })
199 }
200
201 fn build_co_authors_query<'a>(shas: impl IntoIterator<Item = &'a CommitSha>) -> String {
202 const FRAGMENT: &str = r#"
203 ... on Commit {
204 author {
205 name
206 email
207 user { login }
208 }
209 authors(first: 10) {
210 nodes {
211 name
212 email
213 user { login }
214 }
215 }
216 }
217 "#;
218
219 let objects: String = shas
220 .into_iter()
221 .map(|commit_sha| {
222 format!(
223 "commit{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
224 sha = **commit_sha
225 )
226 })
227 .join("\n");
228
229 format!("{{ repository(owner: \"{ORG}\", name: \"{REPO}\") {{ {objects} }} }}")
230 .replace("\n", "")
231 }
232
233 async fn graphql<R: octocrab::FromResponse>(
234 &self,
235 query: &serde_json::Value,
236 ) -> octocrab::Result<R> {
237 self.client.graphql(query).await
238 }
239
240 async fn get_all<T: DeserializeOwned + 'static>(
241 &self,
242 page: Page<T>,
243 ) -> octocrab::Result<Vec<T>> {
244 self.get_filtered(page, |_| true).await
245 }
246
247 async fn get_filtered<T: DeserializeOwned + 'static>(
248 &self,
249 page: Page<T>,
250 predicate: impl Fn(&T) -> bool,
251 ) -> octocrab::Result<Vec<T>> {
252 let stream = page.into_stream(&self.client);
253 pin!(stream);
254
255 let mut results = Vec::new();
256
257 while let Some(item) = stream.try_next().await?
258 && predicate(&item)
259 {
260 results.push(item);
261 }
262
263 Ok(results)
264 }
265 }
266
267 #[async_trait::async_trait(?Send)]
268 impl GitHubApiClient for OctocrabClient {
269 async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
270 let pr = self.client.pulls(ORG, REPO).get(pr_number).await?;
271 Ok(PullRequestData {
272 number: pr.number,
273 user: pr.user.map(|user| GitHubUser { login: user.login }),
274 merged_by: pr.merged_by.map(|user| GitHubUser { login: user.login }),
275 })
276 }
277
278 async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
279 let page = self
280 .client
281 .pulls(ORG, REPO)
282 .list_reviews(pr_number)
283 .per_page(PAGE_SIZE)
284 .send()
285 .await?;
286
287 let reviews = self.get_all(page).await?;
288
289 Ok(reviews
290 .into_iter()
291 .map(|review| PullRequestReview {
292 user: review.user.map(|user| GitHubUser { login: user.login }),
293 state: review.state.map(|state| match state {
294 OctocrabReviewState::Approved => ReviewState::Approved,
295 _ => ReviewState::Other,
296 }),
297 })
298 .collect())
299 }
300
301 async fn get_pull_request_comments(
302 &self,
303 pr_number: u64,
304 ) -> Result<Vec<PullRequestComment>> {
305 let page = self
306 .client
307 .issues(ORG, REPO)
308 .list_comments(pr_number)
309 .per_page(PAGE_SIZE)
310 .send()
311 .await?;
312
313 let comments = self.get_all(page).await?;
314
315 Ok(comments
316 .into_iter()
317 .map(|comment| PullRequestComment {
318 user: GitHubUser {
319 login: comment.user.login,
320 },
321 body: comment.body,
322 })
323 .collect())
324 }
325
326 async fn get_commit_authors(
327 &self,
328 commit_shas: &[&CommitSha],
329 ) -> Result<AuthorsForCommits> {
330 let query = Self::build_co_authors_query(commit_shas.iter().copied());
331 let query = serde_json::json!({ "query": query });
332 let mut response = self.graphql::<serde_json::Value>(&query).await?;
333
334 response
335 .get_mut("data")
336 .and_then(|data| data.get_mut("repository"))
337 .and_then(|repo| repo.as_object_mut())
338 .ok_or_else(|| anyhow::anyhow!("Unexpected response format!"))
339 .and_then(|commit_data| {
340 let mut response_map = serde_json::Map::with_capacity(commit_data.len());
341
342 for (key, value) in commit_data.iter_mut() {
343 let key_without_prefix = key.strip_prefix("commit").unwrap_or(key);
344 if let Some(authors) = value.get_mut("authors") {
345 if let Some(nodes) = authors.get("nodes") {
346 *authors = nodes.clone();
347 }
348 }
349
350 response_map.insert(key_without_prefix.to_owned(), value.clone());
351 }
352
353 serde_json::from_value(serde_json::Value::Object(response_map))
354 .context("Failed to deserialize commit authors")
355 })
356 }
357
358 async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
359 let page = self
360 .client
361 .orgs(ORG)
362 .list_members()
363 .per_page(PAGE_SIZE)
364 .send()
365 .await?;
366
367 let members = self.get_all(page).await?;
368
369 Ok(members
370 .into_iter()
371 .any(|member| member.login == login.as_str()))
372 }
373
374 async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool> {
375 // TODO: octocrab fails to deserialize the permission response and
376 // does not adhere to the scheme laid out at
377 // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user
378
379 #[derive(serde::Deserialize)]
380 #[serde(rename_all = "lowercase")]
381 enum RepoPermission {
382 Admin,
383 Write,
384 Read,
385 #[serde(other)]
386 Other,
387 }
388
389 #[derive(serde::Deserialize)]
390 struct RepositoryPermissions {
391 permission: RepoPermission,
392 }
393
394 self.client
395 .get::<RepositoryPermissions, _, _>(
396 format!(
397 "/repos/{ORG}/{REPO}/collaborators/{user}/permission",
398 user = login.as_str()
399 ),
400 None::<&()>,
401 )
402 .await
403 .map(|response| {
404 matches!(
405 response.permission,
406 RepoPermission::Write | RepoPermission::Admin
407 )
408 })
409 .map_err(Into::into)
410 }
411
412 async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()> {
413 if self
414 .get_filtered(
415 self.client
416 .issues(ORG, REPO)
417 .list_labels_for_issue(pr_number)
418 .per_page(PAGE_SIZE)
419 .send()
420 .await?,
421 |pr_label| pr_label.name == label,
422 )
423 .await
424 .is_ok_and(|l| l.is_empty())
425 {
426 self.client
427 .issues(ORG, REPO)
428 .add_labels(pr_number, &[label.to_owned()])
429 .await?;
430 }
431
432 Ok(())
433 }
434 }
435}
436
437#[cfg(feature = "octo-client")]
438pub use octo_client::OctocrabClient;