1use std::{collections::HashMap, fmt, ops::Not, rc::Rc};
2
3use anyhow::Result;
4use derive_more::Deref;
5use serde::Deserialize;
6
7use crate::git::CommitSha;
8
/// Name of the label that marks a pull request as needing review.
pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
10
/// A GitHub account, identified only by its login name.
#[derive(Clone, Debug)]
pub struct GitHubUser {
    /// The account's login (username).
    pub login: String,
}
15
16#[derive(Debug, Clone)]
17pub struct PullRequestData {
18 pub number: u64,
19 pub user: Option<GitHubUser>,
20 pub merged_by: Option<GitHubUser>,
21}
22
/// A simplified pull request review state that only distinguishes approvals.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReviewState {
    /// The reviewer approved the pull request.
    Approved,
    /// Any review state other than an approval.
    Other,
}
28
29#[derive(Debug, Clone)]
30pub struct PullRequestReview {
31 pub user: Option<GitHubUser>,
32 pub state: Option<ReviewState>,
33}
34
35#[derive(Debug, Clone)]
36pub struct PullRequestComment {
37 pub user: GitHubUser,
38 pub body: Option<String>,
39}
40
41#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
42pub struct GithubLogin {
43 login: String,
44}
45
46impl GithubLogin {
47 pub(crate) fn new(login: String) -> Self {
48 Self { login }
49 }
50}
51
52impl fmt::Display for GithubLogin {
53 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
54 write!(formatter, "@{}", self.login)
55 }
56}
57
58#[derive(Debug, Deserialize, Clone)]
59pub struct CommitAuthor {
60 name: String,
61 email: String,
62 user: Option<GithubLogin>,
63}
64
65impl CommitAuthor {
66 pub(crate) fn user(&self) -> Option<&GithubLogin> {
67 self.user.as_ref()
68 }
69}
70
71impl PartialEq for CommitAuthor {
72 fn eq(&self, other: &Self) -> bool {
73 self.user.as_ref().zip(other.user.as_ref()).map_or_else(
74 || self.email == other.email || self.name == other.name,
75 |(l, r)| l == r,
76 )
77 }
78}
79
80impl fmt::Display for CommitAuthor {
81 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
82 match self.user.as_ref() {
83 Some(user) => write!(formatter, "{} ({user})", self.name),
84 None => write!(formatter, "{} ({})", self.name, self.email),
85 }
86 }
87}
88
89#[derive(Debug, Deserialize)]
90pub struct CommitAuthors {
91 #[serde(rename = "author")]
92 primary_author: CommitAuthor,
93 #[serde(rename = "authors")]
94 co_authors: Vec<CommitAuthor>,
95}
96
97impl CommitAuthors {
98 pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
99 self.co_authors.is_empty().not().then(|| {
100 self.co_authors
101 .iter()
102 .filter(|co_author| *co_author != &self.primary_author)
103 })
104 }
105}
106
107#[derive(Debug, Deserialize, Deref)]
108pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
109
110#[async_trait::async_trait(?Send)]
111pub trait GitHubApiClient {
112 async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData>;
113 async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>>;
114 async fn get_pull_request_comments(&self, pr_number: u64) -> Result<Vec<PullRequestComment>>;
115 async fn get_commit_authors(&self, commit_shas: &[&CommitSha]) -> Result<AuthorsForCommits>;
116 async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool>;
117 async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()>;
118}
119
120pub struct GitHubClient {
121 api: Rc<dyn GitHubApiClient>,
122}
123
124impl GitHubClient {
125 pub fn new(api: Rc<dyn GitHubApiClient>) -> Self {
126 Self { api }
127 }
128
129 #[cfg(feature = "octo-client")]
130 pub async fn for_app(app_id: u64, app_private_key: &str) -> Result<Self> {
131 let client = OctocrabClient::new(app_id, app_private_key).await?;
132 Ok(Self::new(Rc::new(client)))
133 }
134
135 pub async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
136 self.api.get_pull_request(pr_number).await
137 }
138
139 pub async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
140 self.api.get_pull_request_reviews(pr_number).await
141 }
142
143 pub async fn get_pull_request_comments(
144 &self,
145 pr_number: u64,
146 ) -> Result<Vec<PullRequestComment>> {
147 self.api.get_pull_request_comments(pr_number).await
148 }
149
150 pub async fn get_commit_authors<'a>(
151 &self,
152 commit_shas: impl IntoIterator<Item = &'a CommitSha>,
153 ) -> Result<AuthorsForCommits> {
154 let shas: Vec<&CommitSha> = commit_shas.into_iter().collect();
155 self.api.get_commit_authors(&shas).await
156 }
157
158 pub async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
159 self.api.check_org_membership(login).await
160 }
161
162 pub async fn add_label_to_pull_request(&self, label: &str, pr_number: u64) -> Result<()> {
163 self.api
164 .ensure_pull_request_has_label(label, pr_number)
165 .await
166 }
167}
168
#[cfg(feature = "octo-client")]
mod octo_client {
    use anyhow::{Context, Result};
    use futures::TryStreamExt as _;
    use itertools::Itertools;
    use jsonwebtoken::EncodingKey;
    use octocrab::{
        Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
        service::middleware::cache::mem::InMemoryCache,
    };
    use serde::de::DeserializeOwned;
    use tokio::pin;

    use crate::git::CommitSha;

    use super::{
        AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment,
        PullRequestData, PullRequestReview, ReviewState,
    };

    /// Number of items requested per page from paginated REST endpoints.
    const PAGE_SIZE: u8 = 100;
    /// Organization that owns the repository and the GitHub App installation.
    const ORG: &str = "zed-industries";
    /// Repository queried within `ORG`.
    const REPO: &str = "zed";

    /// `GitHubApiClient` backed by Octocrab, authenticated as a GitHub App installation
    /// on `ORG`.
    pub struct OctocrabClient {
        client: Octocrab,
    }

    impl OctocrabClient {
        /// Authenticates as GitHub App `app_id` using its PEM-encoded RSA private key, then
        /// switches to that app's installation on `ORG`.
        ///
        /// # Errors
        ///
        /// Fails if the key cannot be parsed, the client cannot be built, the installations
        /// cannot be listed, or no installation belongs to `ORG`.
        pub async fn new(app_id: u64, app_private_key: &str) -> Result<Self> {
            let octocrab = Octocrab::builder()
                .cache(InMemoryCache::new())
                .app(
                    app_id.into(),
                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
                )
                .build()?;

            let installations = octocrab
                .apps()
                .installations()
                .send()
                .await
                .context("Failed to fetch installations")?
                .take_items();

            let installation_id = installations
                .into_iter()
                .find(|installation| installation.account.login == ORG)
                .context("Could not find Zed repository in installations")?
                .id;

            let client = octocrab.installation(installation_id)?;
            Ok(Self { client })
        }

        /// Builds one GraphQL query that fetches the author and up to ten listed authors
        /// for every commit in `shas`.
        ///
        /// Each commit is aliased `commit<sha>`. GraphQL names may not start with a digit,
        /// so the bare SHA would not be a valid alias.
        fn build_co_authors_query<'a>(shas: impl IntoIterator<Item = &'a CommitSha>) -> String {
            const FRAGMENT: &str = r#"
                ... on Commit {
                    author {
                        name
                        email
                        user { login }
                    }
                    authors(first: 10) {
                        nodes {
                            name
                            email
                            user { login }
                        }
                    }
                }
            "#;

            let objects: String = shas
                .into_iter()
                .map(|commit_sha| {
                    format!(
                        "commit{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
                        sha = **commit_sha
                    )
                })
                .join("\n");

            format!("{{ repository(owner: \"{ORG}\", name: \"{REPO}\") {{ {objects} }} }}")
                .replace("\n", "")
        }

        /// Sends a raw GraphQL request through the installation client.
        async fn graphql<R: octocrab::FromResponse>(
            &self,
            query: &serde_json::Value,
        ) -> octocrab::Result<R> {
            self.client.graphql(query).await
        }

        /// Reads every page starting at `page` and returns all items.
        async fn get_all<T: DeserializeOwned + 'static>(
            &self,
            page: Page<T>,
        ) -> octocrab::Result<Vec<T>> {
            self.get_filtered(page, |_| true).await
        }

        /// Reads every page starting at `page` and returns only the items for which
        /// `predicate` returns `true`.
        ///
        /// # Errors
        ///
        /// Propagates the first error raised while fetching a subsequent page.
        async fn get_filtered<T: DeserializeOwned + 'static>(
            &self,
            page: Page<T>,
            predicate: impl Fn(&T) -> bool,
        ) -> octocrab::Result<Vec<T>> {
            let stream = page.into_stream(&self.client);
            pin!(stream);

            let mut results = Vec::new();

            // Drain the whole stream. A non-matching item is skipped; it does not end the
            // iteration.
            while let Some(item) = stream.try_next().await? {
                if predicate(&item) {
                    results.push(item);
                }
            }

            Ok(results)
        }
    }

    #[async_trait::async_trait(?Send)]
    impl GitHubApiClient for OctocrabClient {
        async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
            let pr = self.client.pulls(ORG, REPO).get(pr_number).await?;
            Ok(PullRequestData {
                number: pr.number,
                user: pr.user.map(|user| GitHubUser { login: user.login }),
                merged_by: pr.merged_by.map(|user| GitHubUser { login: user.login }),
            })
        }

        async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
            let page = self
                .client
                .pulls(ORG, REPO)
                .list_reviews(pr_number)
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let reviews = self.get_all(page).await?;

            Ok(reviews
                .into_iter()
                .map(|review| PullRequestReview {
                    user: review.user.map(|user| GitHubUser { login: user.login }),
                    state: review.state.map(|state| match state {
                        OctocrabReviewState::Approved => ReviewState::Approved,
                        _ => ReviewState::Other,
                    }),
                })
                .collect())
        }

        async fn get_pull_request_comments(
            &self,
            pr_number: u64,
        ) -> Result<Vec<PullRequestComment>> {
            let page = self
                .client
                .issues(ORG, REPO)
                .list_comments(pr_number)
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let comments = self.get_all(page).await?;

            Ok(comments
                .into_iter()
                .map(|comment| PullRequestComment {
                    user: GitHubUser {
                        login: comment.user.login,
                    },
                    body: comment.body,
                })
                .collect())
        }

        async fn get_commit_authors(
            &self,
            commit_shas: &[&CommitSha],
        ) -> Result<AuthorsForCommits> {
            // With no SHAs the query would contain an empty selection set, which is not
            // valid GraphQL. Skip the request instead.
            if commit_shas.is_empty() {
                return Ok(AuthorsForCommits(Default::default()));
            }

            let query = Self::build_co_authors_query(commit_shas.iter().copied());
            let query = serde_json::json!({ "query": query });
            let mut response = self.graphql::<serde_json::Value>(&query).await?;

            let commit_data = response
                .get_mut("data")
                .and_then(|data| data.get_mut("repository"))
                .and_then(|repo| repo.as_object_mut())
                .ok_or_else(|| anyhow::anyhow!("Unexpected response format!"))?;

            // Reshape `{ "commit<sha>": { author, authors: { nodes: [..] } } }` into
            // `{ "<sha>": { author, authors: [..] } }` so it matches `CommitAuthors`.
            // Values are moved out of the response instead of being cloned.
            let response_map: serde_json::Map<String, serde_json::Value> =
                std::mem::take(commit_data)
                    .into_iter()
                    .map(|(key, mut value)| {
                        if let Some(authors) = value.get_mut("authors") {
                            if let Some(nodes) =
                                authors.get_mut("nodes").map(serde_json::Value::take)
                            {
                                *authors = nodes;
                            }
                        }

                        let sha = key.strip_prefix("commit").map(str::to_owned).unwrap_or(key);
                        (sha, value)
                    })
                    .collect();

            serde_json::from_value(serde_json::Value::Object(response_map))
                .context("Failed to deserialize commit authors")
        }

        async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
            let page = self
                .client
                .orgs(ORG)
                .list_members()
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let members = self.get_all(page).await?;

            Ok(members
                .into_iter()
                .any(|member| member.login == login.as_str()))
        }

        /// Adds `label` to the pull request unless one of its labels already has that name.
        ///
        /// # Errors
        ///
        /// Propagates failures from listing the labels (on any page) and from adding the
        /// label.
        async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()> {
            let page = self
                .client
                .issues(ORG, REPO)
                .list_labels_for_issue(pr_number)
                .per_page(PAGE_SIZE)
                .send()
                .await?;

            let matching_labels = self
                .get_filtered(page, |pr_label| pr_label.name == label)
                .await?;

            if matching_labels.is_empty() {
                self.client
                    .issues(ORG, REPO)
                    .add_labels(pr_number, &[label.to_owned()])
                    .await?;
            }

            Ok(())
        }
    }
}
422
423#[cfg(feature = "octo-client")]
424pub use octo_client::OctocrabClient;