1use std::{collections::HashMap, fmt, ops::Not, rc::Rc};
2
3use anyhow::Result;
4use derive_more::Deref;
5use serde::Deserialize;
6
7use crate::git::CommitSha;
8
/// Name of the GitHub label that marks a pull request as needing review.
pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
10
/// A GitHub account, identified only by its login name.
#[derive(Clone, Debug)]
pub struct GitHubUser {
    /// The account's login name.
    pub login: String,
}
15
16#[derive(Debug, Clone)]
17pub struct PullRequestData {
18 pub number: u64,
19 pub user: Option<GitHubUser>,
20 pub merged_by: Option<GitHubUser>,
21}
22
/// The outcome of a pull request review, reduced to "approved or not".
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReviewState {
    /// The reviewer approved the pull request.
    Approved,
    /// Any review state other than an approval.
    Other,
}
28
29#[derive(Debug, Clone)]
30pub struct PullRequestReview {
31 pub user: Option<GitHubUser>,
32 pub state: Option<ReviewState>,
33 pub body: Option<String>,
34}
35
36impl PullRequestReview {
37 pub fn with_body(self, body: impl ToString) -> Self {
38 Self {
39 body: Some(body.to_string()),
40 ..self
41 }
42 }
43}
44
45#[derive(Debug, Clone)]
46pub struct PullRequestComment {
47 pub user: GitHubUser,
48 pub body: Option<String>,
49}
50
51#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
52pub struct GithubLogin {
53 login: String,
54}
55
56impl GithubLogin {
57 pub fn new(login: String) -> Self {
58 Self { login }
59 }
60}
61
62impl fmt::Display for GithubLogin {
63 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
64 write!(formatter, "@{}", self.login)
65 }
66}
67
68#[derive(Debug, Deserialize, Clone)]
69pub struct CommitAuthor {
70 name: String,
71 email: String,
72 user: Option<GithubLogin>,
73}
74
75impl CommitAuthor {
76 pub(crate) fn user(&self) -> Option<&GithubLogin> {
77 self.user.as_ref()
78 }
79}
80
81impl PartialEq for CommitAuthor {
82 fn eq(&self, other: &Self) -> bool {
83 self.user.as_ref().zip(other.user.as_ref()).map_or_else(
84 || self.email == other.email || self.name == other.name,
85 |(l, r)| l == r,
86 )
87 }
88}
89
90impl fmt::Display for CommitAuthor {
91 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
92 match self.user.as_ref() {
93 Some(user) => write!(formatter, "{} ({user})", self.name),
94 None => write!(formatter, "{} ({})", self.name, self.email),
95 }
96 }
97}
98
99#[derive(Debug, Deserialize)]
100pub struct CommitAuthors {
101 #[serde(rename = "author")]
102 primary_author: CommitAuthor,
103 #[serde(rename = "authors")]
104 co_authors: Vec<CommitAuthor>,
105}
106
107impl CommitAuthors {
108 pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
109 self.co_authors.is_empty().not().then(|| {
110 self.co_authors
111 .iter()
112 .filter(|co_author| *co_author != &self.primary_author)
113 })
114 }
115}
116
117#[derive(Debug, Deserialize, Deref)]
118pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
119
120#[async_trait::async_trait(?Send)]
121pub trait GitHubApiClient {
122 async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData>;
123 async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>>;
124 async fn get_pull_request_comments(&self, pr_number: u64) -> Result<Vec<PullRequestComment>>;
125 async fn get_commit_authors(&self, commit_shas: &[&CommitSha]) -> Result<AuthorsForCommits>;
126 async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool>;
127 async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool>;
128 async fn actor_has_repository_write_permission(
129 &self,
130 login: &GithubLogin,
131 ) -> anyhow::Result<bool> {
132 Ok(self.check_org_membership(login).await?
133 || self.check_repo_write_permission(login).await?)
134 }
135 async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()>;
136}
137
138#[derive(Deref)]
139pub struct GitHubClient {
140 api: Rc<dyn GitHubApiClient>,
141}
142
143impl GitHubClient {
144 pub fn new(api: Rc<dyn GitHubApiClient>) -> Self {
145 Self { api }
146 }
147
148 #[cfg(feature = "octo-client")]
149 pub async fn for_app(app_id: u64, app_private_key: &str) -> Result<Self> {
150 let client = OctocrabClient::new(app_id, app_private_key).await?;
151 Ok(Self::new(Rc::new(client)))
152 }
153}
154
155#[cfg(feature = "octo-client")]
156mod octo_client {
157 use anyhow::{Context, Result};
158 use futures::TryStreamExt as _;
159 use itertools::Itertools;
160 use jsonwebtoken::EncodingKey;
161 use octocrab::{
162 Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
163 service::middleware::cache::mem::InMemoryCache,
164 };
165 use serde::de::DeserializeOwned;
166 use tokio::pin;
167
168 use crate::git::CommitSha;
169
170 use super::{
171 AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment,
172 PullRequestData, PullRequestReview, ReviewState,
173 };
174
    /// Page size requested from paginated REST endpoints.
    const PAGE_SIZE: u8 = 100;
    /// Owner (organization) of the repository every request targets.
    const ORG: &str = "zed-industries";
    /// Name of the repository every request targets.
    const REPO: &str = "zed";

    /// `GitHubApiClient` implementation backed by an `octocrab` client that is
    /// authenticated as a GitHub App installation (see `OctocrabClient::new`).
    pub struct OctocrabClient {
        client: Octocrab,
    }
182
183 impl OctocrabClient {
184 pub async fn new(app_id: u64, app_private_key: &str) -> Result<Self> {
185 let octocrab = Octocrab::builder()
186 .cache(InMemoryCache::new())
187 .app(
188 app_id.into(),
189 EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
190 )
191 .build()?;
192
193 let installations = octocrab
194 .apps()
195 .installations()
196 .send()
197 .await
198 .context("Failed to fetch installations")?
199 .take_items();
200
201 let installation_id = installations
202 .into_iter()
203 .find(|installation| installation.account.login == ORG)
204 .context("Could not find Zed repository in installations")?
205 .id;
206
207 let client = octocrab.installation(installation_id)?;
208 Ok(Self { client })
209 }
210
211 fn build_co_authors_query<'a>(shas: impl IntoIterator<Item = &'a CommitSha>) -> String {
212 const FRAGMENT: &str = r#"
213 ... on Commit {
214 author {
215 name
216 email
217 user { login }
218 }
219 authors(first: 10) {
220 nodes {
221 name
222 email
223 user { login }
224 }
225 }
226 }
227 "#;
228
229 let objects: String = shas
230 .into_iter()
231 .map(|commit_sha| {
232 format!(
233 "commit{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
234 sha = **commit_sha
235 )
236 })
237 .join("\n");
238
239 format!("{{ repository(owner: \"{ORG}\", name: \"{REPO}\") {{ {objects} }} }}")
240 .replace("\n", "")
241 }
242
243 async fn graphql<R: octocrab::FromResponse>(
244 &self,
245 query: &serde_json::Value,
246 ) -> octocrab::Result<R> {
247 self.client.graphql(query).await
248 }
249
250 async fn get_all<T: DeserializeOwned + 'static>(
251 &self,
252 page: Page<T>,
253 ) -> octocrab::Result<Vec<T>> {
254 self.get_filtered(page, |_| true).await
255 }
256
257 async fn get_filtered<T: DeserializeOwned + 'static>(
258 &self,
259 page: Page<T>,
260 predicate: impl Fn(&T) -> bool,
261 ) -> octocrab::Result<Vec<T>> {
262 let stream = page.into_stream(&self.client);
263 pin!(stream);
264
265 let mut results = Vec::new();
266
267 while let Some(item) = stream.try_next().await?
268 && predicate(&item)
269 {
270 results.push(item);
271 }
272
273 Ok(results)
274 }
275 }
276
277 #[async_trait::async_trait(?Send)]
278 impl GitHubApiClient for OctocrabClient {
279 async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
280 let pr = self.client.pulls(ORG, REPO).get(pr_number).await?;
281 Ok(PullRequestData {
282 number: pr.number,
283 user: pr.user.map(|user| GitHubUser { login: user.login }),
284 merged_by: pr.merged_by.map(|user| GitHubUser { login: user.login }),
285 })
286 }
287
288 async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
289 let page = self
290 .client
291 .pulls(ORG, REPO)
292 .list_reviews(pr_number)
293 .per_page(PAGE_SIZE)
294 .send()
295 .await?;
296
297 let reviews = self.get_all(page).await?;
298
299 Ok(reviews
300 .into_iter()
301 .map(|review| PullRequestReview {
302 user: review.user.map(|user| GitHubUser { login: user.login }),
303 state: review.state.map(|state| match state {
304 OctocrabReviewState::Approved => ReviewState::Approved,
305 _ => ReviewState::Other,
306 }),
307 body: review.body,
308 })
309 .collect())
310 }
311
312 async fn get_pull_request_comments(
313 &self,
314 pr_number: u64,
315 ) -> Result<Vec<PullRequestComment>> {
316 let page = self
317 .client
318 .issues(ORG, REPO)
319 .list_comments(pr_number)
320 .per_page(PAGE_SIZE)
321 .send()
322 .await?;
323
324 let comments = self.get_all(page).await?;
325
326 Ok(comments
327 .into_iter()
328 .map(|comment| PullRequestComment {
329 user: GitHubUser {
330 login: comment.user.login,
331 },
332 body: comment.body,
333 })
334 .collect())
335 }
336
337 async fn get_commit_authors(
338 &self,
339 commit_shas: &[&CommitSha],
340 ) -> Result<AuthorsForCommits> {
341 let query = Self::build_co_authors_query(commit_shas.iter().copied());
342 let query = serde_json::json!({ "query": query });
343 let mut response = self.graphql::<serde_json::Value>(&query).await?;
344
345 response
346 .get_mut("data")
347 .and_then(|data| data.get_mut("repository"))
348 .and_then(|repo| repo.as_object_mut())
349 .ok_or_else(|| anyhow::anyhow!("Unexpected response format!"))
350 .and_then(|commit_data| {
351 let mut response_map = serde_json::Map::with_capacity(commit_data.len());
352
353 for (key, value) in commit_data.iter_mut() {
354 let key_without_prefix = key.strip_prefix("commit").unwrap_or(key);
355 if let Some(authors) = value.get_mut("authors") {
356 if let Some(nodes) = authors.get("nodes") {
357 *authors = nodes.clone();
358 }
359 }
360
361 response_map.insert(key_without_prefix.to_owned(), value.clone());
362 }
363
364 serde_json::from_value(serde_json::Value::Object(response_map))
365 .context("Failed to deserialize commit authors")
366 })
367 }
368
369 async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
370 let page = self
371 .client
372 .orgs(ORG)
373 .list_members()
374 .per_page(PAGE_SIZE)
375 .send()
376 .await?;
377
378 let members = self.get_all(page).await?;
379
380 Ok(members
381 .into_iter()
382 .any(|member| member.login == login.as_str()))
383 }
384
385 async fn check_repo_write_permission(&self, login: &GithubLogin) -> Result<bool> {
386 // TODO: octocrab fails to deserialize the permission response and
387 // does not adhere to the scheme laid out at
388 // https://docs.github.com/en/rest/collaborators/collaborators?apiVersion=2026-03-10#get-repository-permissions-for-a-user
389
390 #[derive(serde::Deserialize)]
391 #[serde(rename_all = "lowercase")]
392 enum RepoPermission {
393 Admin,
394 Write,
395 Read,
396 #[serde(other)]
397 Other,
398 }
399
400 #[derive(serde::Deserialize)]
401 struct RepositoryPermissions {
402 permission: RepoPermission,
403 }
404
405 self.client
406 .get::<RepositoryPermissions, _, _>(
407 format!(
408 "/repos/{ORG}/{REPO}/collaborators/{user}/permission",
409 user = login.as_str()
410 ),
411 None::<&()>,
412 )
413 .await
414 .map(|response| {
415 matches!(
416 response.permission,
417 RepoPermission::Write | RepoPermission::Admin
418 )
419 })
420 .map_err(Into::into)
421 }
422
423 async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()> {
424 if self
425 .get_filtered(
426 self.client
427 .issues(ORG, REPO)
428 .list_labels_for_issue(pr_number)
429 .per_page(PAGE_SIZE)
430 .send()
431 .await?,
432 |pr_label| pr_label.name == label,
433 )
434 .await
435 .is_ok_and(|l| l.is_empty())
436 {
437 self.client
438 .issues(ORG, REPO)
439 .add_labels(pr_number, &[label.to_owned()])
440 .await?;
441 }
442
443 Ok(())
444 }
445 }
446}
447
448#[cfg(feature = "octo-client")]
449pub use octo_client::OctocrabClient;