1pub mod billing;
2pub mod contributors;
3pub mod events;
4pub mod extensions;
5pub mod ips_file;
6pub mod slack;
7
8use crate::db::Database;
9use crate::{
10 AppState, Error, Result, auth,
11 db::{User, UserId},
12 rpc,
13};
14use anyhow::Context as _;
15use axum::{
16 Extension, Json, Router,
17 body::Body,
18 extract::{Path, Query},
19 headers::Header,
20 http::{self, HeaderName, Request, StatusCode},
21 middleware::{self, Next},
22 response::IntoResponse,
23 routing::{get, post},
24};
25use axum_extra::response::ErasedJson;
26use serde::{Deserialize, Serialize};
27use std::sync::{Arc, OnceLock};
28use tower::ServiceBuilder;
29
30pub use extensions::fetch_extensions_from_blob_store_periodically;
31
32pub struct CloudflareIpCountryHeader(String);
33
34impl Header for CloudflareIpCountryHeader {
35 fn name() -> &'static HeaderName {
36 static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
37 CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
38 }
39
40 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
41 where
42 Self: Sized,
43 I: Iterator<Item = &'i axum::http::HeaderValue>,
44 {
45 let country_code = values
46 .next()
47 .ok_or_else(axum::headers::Error::invalid)?
48 .to_str()
49 .map_err(|_| axum::headers::Error::invalid())?;
50
51 Ok(Self(country_code.to_string()))
52 }
53
54 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
55 unimplemented!()
56 }
57}
58
59impl std::fmt::Display for CloudflareIpCountryHeader {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 write!(f, "{}", self.0)
62 }
63}
64
65pub struct SystemIdHeader(String);
66
67impl Header for SystemIdHeader {
68 fn name() -> &'static HeaderName {
69 static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
70 SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
71 }
72
73 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
74 where
75 Self: Sized,
76 I: Iterator<Item = &'i axum::http::HeaderValue>,
77 {
78 let system_id = values
79 .next()
80 .ok_or_else(axum::headers::Error::invalid)?
81 .to_str()
82 .map_err(|_| axum::headers::Error::invalid())?;
83
84 Ok(Self(system_id.to_string()))
85 }
86
87 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
88 unimplemented!()
89 }
90}
91
92impl std::fmt::Display for SystemIdHeader {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 write!(f, "{}", self.0)
95 }
96}
97
98pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
99 Router::new()
100 .route("/user", get(get_authenticated_user))
101 .route("/users/look_up", get(look_up_user))
102 .route("/users/:id/access_tokens", post(create_access_token))
103 .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
104 .merge(billing::router())
105 .merge(contributors::router())
106 .layer(
107 ServiceBuilder::new()
108 .layer(Extension(rpc_server))
109 .layer(middleware::from_fn(validate_api_token)),
110 )
111}
112
113pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
114 let token = req
115 .headers()
116 .get(http::header::AUTHORIZATION)
117 .and_then(|header| header.to_str().ok())
118 .ok_or_else(|| {
119 Error::http(
120 StatusCode::BAD_REQUEST,
121 "missing authorization header".to_string(),
122 )
123 })?
124 .strip_prefix("token ")
125 .ok_or_else(|| {
126 Error::http(
127 StatusCode::BAD_REQUEST,
128 "invalid authorization header".to_string(),
129 )
130 })?;
131
132 let state = req.extensions().get::<Arc<AppState>>().unwrap();
133
134 if token != state.config.api_token {
135 Err(Error::http(
136 StatusCode::UNAUTHORIZED,
137 "invalid authorization token".to_string(),
138 ))?
139 }
140
141 Ok::<_, Error>(next.run(req).await)
142}
143
144#[derive(Debug, Deserialize)]
145struct AuthenticatedUserParams {
146 github_user_id: i32,
147 github_login: String,
148 github_email: Option<String>,
149 github_name: Option<String>,
150 github_user_created_at: chrono::DateTime<chrono::Utc>,
151}
152
153#[derive(Debug, Serialize)]
154struct AuthenticatedUserResponse {
155 user: User,
156 metrics_id: String,
157 feature_flags: Vec<String>,
158}
159
160async fn get_authenticated_user(
161 Query(params): Query<AuthenticatedUserParams>,
162 Extension(app): Extension<Arc<AppState>>,
163) -> Result<Json<AuthenticatedUserResponse>> {
164 let initial_channel_id = app.config.auto_join_channel_id;
165
166 let user = app
167 .db
168 .get_or_create_user_by_github_account(
169 ¶ms.github_login,
170 params.github_user_id,
171 params.github_email.as_deref(),
172 params.github_name.as_deref(),
173 params.github_user_created_at,
174 initial_channel_id,
175 )
176 .await?;
177 let metrics_id = app.db.get_user_metrics_id(user.id).await?;
178 let feature_flags = app.db.get_user_flags(user.id).await?;
179 Ok(Json(AuthenticatedUserResponse {
180 user,
181 metrics_id,
182 feature_flags,
183 }))
184}
185
186#[derive(Debug, Deserialize)]
187struct LookUpUserParams {
188 identifier: String,
189}
190
191#[derive(Debug, Serialize)]
192struct LookUpUserResponse {
193 user: Option<User>,
194}
195
196async fn look_up_user(
197 Query(params): Query<LookUpUserParams>,
198 Extension(app): Extension<Arc<AppState>>,
199) -> Result<Json<LookUpUserResponse>> {
200 let user = resolve_identifier_to_user(&app.db, ¶ms.identifier).await?;
201 let user = if let Some(user) = user {
202 match user {
203 UserOrId::User(user) => Some(user),
204 UserOrId::Id(id) => app.db.get_user_by_id(id).await?,
205 }
206 } else {
207 None
208 };
209
210 Ok(Json(LookUpUserResponse { user }))
211}
212
213enum UserOrId {
214 User(User),
215 Id(UserId),
216}
217
218async fn resolve_identifier_to_user(
219 db: &Arc<Database>,
220 identifier: &str,
221) -> Result<Option<UserOrId>> {
222 if let Some(identifier) = identifier.parse::<i32>().ok() {
223 let user = db.get_user_by_id(UserId(identifier)).await?;
224
225 return Ok(user.map(UserOrId::User));
226 }
227
228 if identifier.starts_with("cus_") {
229 let billing_customer = db
230 .get_billing_customer_by_stripe_customer_id(&identifier)
231 .await?;
232
233 return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)));
234 }
235
236 if identifier.starts_with("sub_") {
237 let billing_subscription = db
238 .get_billing_subscription_by_stripe_subscription_id(&identifier)
239 .await?;
240
241 if let Some(billing_subscription) = billing_subscription {
242 let billing_customer = db
243 .get_billing_customer_by_id(billing_subscription.billing_customer_id)
244 .await?;
245
246 return Ok(
247 billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))
248 );
249 } else {
250 return Ok(None);
251 }
252 }
253
254 if identifier.contains('@') {
255 let user = db.get_user_by_email(identifier).await?;
256
257 return Ok(user.map(UserOrId::User));
258 }
259
260 if let Some(user) = db.get_user_by_github_login(identifier).await? {
261 return Ok(Some(UserOrId::User(user)));
262 }
263
264 Ok(None)
265}
266
267#[derive(Deserialize, Debug)]
268struct CreateUserParams {
269 github_user_id: i32,
270 github_login: String,
271 email_address: String,
272 email_confirmation_code: Option<String>,
273 #[serde(default)]
274 admin: bool,
275 #[serde(default)]
276 invite_count: i32,
277}
278
279async fn get_rpc_server_snapshot(
280 Extension(rpc_server): Extension<Arc<rpc::Server>>,
281) -> Result<ErasedJson> {
282 Ok(ErasedJson::pretty(rpc_server.snapshot().await))
283}
284
285#[derive(Deserialize)]
286struct CreateAccessTokenQueryParams {
287 public_key: String,
288 impersonate: Option<String>,
289}
290
291#[derive(Serialize)]
292struct CreateAccessTokenResponse {
293 user_id: UserId,
294 encrypted_access_token: String,
295}
296
297async fn create_access_token(
298 Path(user_id): Path<UserId>,
299 Query(params): Query<CreateAccessTokenQueryParams>,
300 Extension(app): Extension<Arc<AppState>>,
301) -> Result<Json<CreateAccessTokenResponse>> {
302 let user = app
303 .db
304 .get_user_by_id(user_id)
305 .await?
306 .context("user not found")?;
307
308 let mut impersonated_user_id = None;
309 if let Some(impersonate) = params.impersonate {
310 if user.admin {
311 if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
312 impersonated_user_id = Some(impersonated_user.id);
313 } else {
314 return Err(Error::http(
315 StatusCode::UNPROCESSABLE_ENTITY,
316 format!("user {impersonate} does not exist"),
317 ));
318 }
319 } else {
320 return Err(Error::http(
321 StatusCode::UNAUTHORIZED,
322 "you do not have permission to impersonate other users".to_string(),
323 ));
324 }
325 }
326
327 let access_token =
328 auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
329 let encrypted_access_token =
330 auth::encrypt_access_token(&access_token, params.public_key.clone())?;
331
332 Ok(Json(CreateAccessTokenResponse {
333 user_id: impersonated_user_id.unwrap_or(user_id),
334 encrypted_access_token,
335 }))
336}