1pub mod billing;
2pub mod contributors;
3pub mod events;
4pub mod extensions;
5pub mod ips_file;
6pub mod slack;
7
8use crate::db::Database;
9use crate::{
10 AppState, Error, Result, auth,
11 db::{User, UserId},
12 rpc,
13};
14use anyhow::Context as _;
15use axum::{
16 Extension, Json, Router,
17 body::Body,
18 extract::{Path, Query},
19 headers::Header,
20 http::{self, HeaderName, Request, StatusCode},
21 middleware::{self, Next},
22 response::IntoResponse,
23 routing::{get, post},
24};
25use axum_extra::response::ErasedJson;
26use serde::{Deserialize, Serialize};
27use std::sync::{Arc, OnceLock};
28use tower::ServiceBuilder;
29
30pub use extensions::fetch_extensions_from_blob_store_periodically;
31
32pub struct CloudflareIpCountryHeader(String);
33
34impl Header for CloudflareIpCountryHeader {
35 fn name() -> &'static HeaderName {
36 static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
37 CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
38 }
39
40 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
41 where
42 Self: Sized,
43 I: Iterator<Item = &'i axum::http::HeaderValue>,
44 {
45 let country_code = values
46 .next()
47 .ok_or_else(axum::headers::Error::invalid)?
48 .to_str()
49 .map_err(|_| axum::headers::Error::invalid())?;
50
51 Ok(Self(country_code.to_string()))
52 }
53
54 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
55 unimplemented!()
56 }
57}
58
59impl std::fmt::Display for CloudflareIpCountryHeader {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 write!(f, "{}", self.0)
62 }
63}
64
65pub struct SystemIdHeader(String);
66
67impl Header for SystemIdHeader {
68 fn name() -> &'static HeaderName {
69 static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
70 SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
71 }
72
73 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
74 where
75 Self: Sized,
76 I: Iterator<Item = &'i axum::http::HeaderValue>,
77 {
78 let system_id = values
79 .next()
80 .ok_or_else(axum::headers::Error::invalid)?
81 .to_str()
82 .map_err(|_| axum::headers::Error::invalid())?;
83
84 Ok(Self(system_id.to_string()))
85 }
86
87 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
88 unimplemented!()
89 }
90}
91
92impl std::fmt::Display for SystemIdHeader {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 write!(f, "{}", self.0)
95 }
96}
97
98pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
99 Router::new()
100 .route("/users/look_up", get(look_up_user))
101 .route("/users/:id/access_tokens", post(create_access_token))
102 .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
103 .merge(contributors::router())
104 .layer(
105 ServiceBuilder::new()
106 .layer(Extension(rpc_server))
107 .layer(middleware::from_fn(validate_api_token)),
108 )
109}
110
111pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
112 let token = req
113 .headers()
114 .get(http::header::AUTHORIZATION)
115 .and_then(|header| header.to_str().ok())
116 .ok_or_else(|| {
117 Error::http(
118 StatusCode::BAD_REQUEST,
119 "missing authorization header".to_string(),
120 )
121 })?
122 .strip_prefix("token ")
123 .ok_or_else(|| {
124 Error::http(
125 StatusCode::BAD_REQUEST,
126 "invalid authorization header".to_string(),
127 )
128 })?;
129
130 let state = req.extensions().get::<Arc<AppState>>().unwrap();
131
132 if token != state.config.api_token {
133 Err(Error::http(
134 StatusCode::UNAUTHORIZED,
135 "invalid authorization token".to_string(),
136 ))?
137 }
138
139 Ok::<_, Error>(next.run(req).await)
140}
141
142#[derive(Debug, Deserialize)]
143struct LookUpUserParams {
144 identifier: String,
145}
146
147#[derive(Debug, Serialize)]
148struct LookUpUserResponse {
149 user: Option<User>,
150}
151
152async fn look_up_user(
153 Query(params): Query<LookUpUserParams>,
154 Extension(app): Extension<Arc<AppState>>,
155) -> Result<Json<LookUpUserResponse>> {
156 let user = resolve_identifier_to_user(&app.db, ¶ms.identifier).await?;
157 let user = if let Some(user) = user {
158 match user {
159 UserOrId::User(user) => Some(user),
160 UserOrId::Id(id) => app.db.get_user_by_id(id).await?,
161 }
162 } else {
163 None
164 };
165
166 Ok(Json(LookUpUserResponse { user }))
167}
168
169enum UserOrId {
170 User(User),
171 Id(UserId),
172}
173
174async fn resolve_identifier_to_user(
175 db: &Arc<Database>,
176 identifier: &str,
177) -> Result<Option<UserOrId>> {
178 if let Some(identifier) = identifier.parse::<i32>().ok() {
179 let user = db.get_user_by_id(UserId(identifier)).await?;
180
181 return Ok(user.map(UserOrId::User));
182 }
183
184 if identifier.starts_with("cus_") {
185 let billing_customer = db
186 .get_billing_customer_by_stripe_customer_id(&identifier)
187 .await?;
188
189 return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)));
190 }
191
192 if identifier.starts_with("sub_") {
193 let billing_subscription = db
194 .get_billing_subscription_by_stripe_subscription_id(&identifier)
195 .await?;
196
197 if let Some(billing_subscription) = billing_subscription {
198 let billing_customer = db
199 .get_billing_customer_by_id(billing_subscription.billing_customer_id)
200 .await?;
201
202 return Ok(
203 billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))
204 );
205 } else {
206 return Ok(None);
207 }
208 }
209
210 if identifier.contains('@') {
211 let user = db.get_user_by_email(identifier).await?;
212
213 return Ok(user.map(UserOrId::User));
214 }
215
216 if let Some(user) = db.get_user_by_github_login(identifier).await? {
217 return Ok(Some(UserOrId::User(user)));
218 }
219
220 Ok(None)
221}
222
223#[derive(Deserialize, Debug)]
224struct CreateUserParams {
225 github_user_id: i32,
226 github_login: String,
227 email_address: String,
228 email_confirmation_code: Option<String>,
229 #[serde(default)]
230 admin: bool,
231 #[serde(default)]
232 invite_count: i32,
233}
234
235async fn get_rpc_server_snapshot(
236 Extension(rpc_server): Extension<Arc<rpc::Server>>,
237) -> Result<ErasedJson> {
238 Ok(ErasedJson::pretty(rpc_server.snapshot().await))
239}
240
241#[derive(Deserialize)]
242struct CreateAccessTokenQueryParams {
243 public_key: String,
244 impersonate: Option<String>,
245}
246
247#[derive(Serialize)]
248struct CreateAccessTokenResponse {
249 user_id: UserId,
250 encrypted_access_token: String,
251}
252
253async fn create_access_token(
254 Path(user_id): Path<UserId>,
255 Query(params): Query<CreateAccessTokenQueryParams>,
256 Extension(app): Extension<Arc<AppState>>,
257) -> Result<Json<CreateAccessTokenResponse>> {
258 let user = app
259 .db
260 .get_user_by_id(user_id)
261 .await?
262 .context("user not found")?;
263
264 let mut impersonated_user_id = None;
265 if let Some(impersonate) = params.impersonate {
266 if user.admin {
267 if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
268 impersonated_user_id = Some(impersonated_user.id);
269 } else {
270 return Err(Error::http(
271 StatusCode::UNPROCESSABLE_ENTITY,
272 format!("user {impersonate} does not exist"),
273 ));
274 }
275 } else {
276 return Err(Error::http(
277 StatusCode::UNAUTHORIZED,
278 "you do not have permission to impersonate other users".to_string(),
279 ));
280 }
281 }
282
283 let access_token =
284 auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
285 let encrypted_access_token =
286 auth::encrypt_access_token(&access_token, params.public_key.clone())?;
287
288 Ok(Json(CreateAccessTokenResponse {
289 user_id: impersonated_user_id.unwrap_or(user_id),
290 encrypted_access_token,
291 }))
292}