1pub mod billing;
2pub mod contributors;
3pub mod events;
4pub mod extensions;
5pub mod ips_file;
6pub mod slack;
7
8use crate::{
9 AppState, Error, Result, auth,
10 db::{User, UserId},
11 rpc,
12};
13use anyhow::anyhow;
14use axum::{
15 Extension, Json, Router,
16 body::Body,
17 extract::{Path, Query},
18 headers::Header,
19 http::{self, HeaderName, Request, StatusCode},
20 middleware::{self, Next},
21 response::IntoResponse,
22 routing::{get, post},
23};
24use axum_extra::response::ErasedJson;
25use serde::{Deserialize, Serialize};
26use std::sync::{Arc, OnceLock};
27use tower::ServiceBuilder;
28
29pub use extensions::fetch_extensions_from_blob_store_periodically;
30
31pub struct CloudflareIpCountryHeader(String);
32
33impl Header for CloudflareIpCountryHeader {
34 fn name() -> &'static HeaderName {
35 static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
36 CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
37 }
38
39 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
40 where
41 Self: Sized,
42 I: Iterator<Item = &'i axum::http::HeaderValue>,
43 {
44 let country_code = values
45 .next()
46 .ok_or_else(axum::headers::Error::invalid)?
47 .to_str()
48 .map_err(|_| axum::headers::Error::invalid())?;
49
50 Ok(Self(country_code.to_string()))
51 }
52
53 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
54 unimplemented!()
55 }
56}
57
58impl std::fmt::Display for CloudflareIpCountryHeader {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 write!(f, "{}", self.0)
61 }
62}
63
64pub struct SystemIdHeader(String);
65
66impl Header for SystemIdHeader {
67 fn name() -> &'static HeaderName {
68 static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
69 SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
70 }
71
72 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
73 where
74 Self: Sized,
75 I: Iterator<Item = &'i axum::http::HeaderValue>,
76 {
77 let system_id = values
78 .next()
79 .ok_or_else(axum::headers::Error::invalid)?
80 .to_str()
81 .map_err(|_| axum::headers::Error::invalid())?;
82
83 Ok(Self(system_id.to_string()))
84 }
85
86 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
87 unimplemented!()
88 }
89}
90
91impl std::fmt::Display for SystemIdHeader {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 write!(f, "{}", self.0)
94 }
95}
96
97pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
98 Router::new()
99 .route("/user", get(get_authenticated_user))
100 .route("/users/:id/access_tokens", post(create_access_token))
101 .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
102 .merge(billing::router())
103 .merge(contributors::router())
104 .layer(
105 ServiceBuilder::new()
106 .layer(Extension(rpc_server))
107 .layer(middleware::from_fn(validate_api_token)),
108 )
109}
110
111pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
112 let token = req
113 .headers()
114 .get(http::header::AUTHORIZATION)
115 .and_then(|header| header.to_str().ok())
116 .ok_or_else(|| {
117 Error::http(
118 StatusCode::BAD_REQUEST,
119 "missing authorization header".to_string(),
120 )
121 })?
122 .strip_prefix("token ")
123 .ok_or_else(|| {
124 Error::http(
125 StatusCode::BAD_REQUEST,
126 "invalid authorization header".to_string(),
127 )
128 })?;
129
130 let state = req.extensions().get::<Arc<AppState>>().unwrap();
131
132 if token != state.config.api_token {
133 Err(Error::http(
134 StatusCode::UNAUTHORIZED,
135 "invalid authorization token".to_string(),
136 ))?
137 }
138
139 Ok::<_, Error>(next.run(req).await)
140}
141
142#[derive(Debug, Deserialize)]
143struct AuthenticatedUserParams {
144 github_user_id: i32,
145 github_login: String,
146 github_email: Option<String>,
147 github_name: Option<String>,
148 github_user_created_at: chrono::DateTime<chrono::Utc>,
149}
150
151#[derive(Debug, Serialize)]
152struct AuthenticatedUserResponse {
153 user: User,
154 metrics_id: String,
155 feature_flags: Vec<String>,
156}
157
158async fn get_authenticated_user(
159 Query(params): Query<AuthenticatedUserParams>,
160 Extension(app): Extension<Arc<AppState>>,
161) -> Result<Json<AuthenticatedUserResponse>> {
162 let initial_channel_id = app.config.auto_join_channel_id;
163
164 let user = app
165 .db
166 .get_or_create_user_by_github_account(
167 ¶ms.github_login,
168 params.github_user_id,
169 params.github_email.as_deref(),
170 params.github_name.as_deref(),
171 params.github_user_created_at,
172 initial_channel_id,
173 )
174 .await?;
175 let metrics_id = app.db.get_user_metrics_id(user.id).await?;
176 let feature_flags = app.db.get_user_flags(user.id).await?;
177 Ok(Json(AuthenticatedUserResponse {
178 user,
179 metrics_id,
180 feature_flags,
181 }))
182}
183
/// Deserializable parameters, presumably for creating a user.
///
/// NOTE(review): nothing in this file reads this struct. It may be used by a
/// child module via `super::CreateUserParams`, or it may be left over; confirm
/// before removing.
#[derive(Deserialize, Debug)]
struct CreateUserParams {
    github_user_id: i32,
    github_login: String,
    email_address: String,
    email_confirmation_code: Option<String>,
    // Falls back to `false` when the field is absent from the input.
    #[serde(default)]
    admin: bool,
    // Falls back to `0` when the field is absent from the input.
    #[serde(default)]
    invite_count: i32,
}
195
196async fn get_rpc_server_snapshot(
197 Extension(rpc_server): Extension<Arc<rpc::Server>>,
198) -> Result<ErasedJson> {
199 Ok(ErasedJson::pretty(rpc_server.snapshot().await))
200}
201
202#[derive(Deserialize)]
203struct CreateAccessTokenQueryParams {
204 public_key: String,
205 impersonate: Option<String>,
206}
207
208#[derive(Serialize)]
209struct CreateAccessTokenResponse {
210 user_id: UserId,
211 encrypted_access_token: String,
212}
213
214async fn create_access_token(
215 Path(user_id): Path<UserId>,
216 Query(params): Query<CreateAccessTokenQueryParams>,
217 Extension(app): Extension<Arc<AppState>>,
218) -> Result<Json<CreateAccessTokenResponse>> {
219 let user = app
220 .db
221 .get_user_by_id(user_id)
222 .await?
223 .ok_or_else(|| anyhow!("user not found"))?;
224
225 let mut impersonated_user_id = None;
226 if let Some(impersonate) = params.impersonate {
227 if user.admin {
228 if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
229 impersonated_user_id = Some(impersonated_user.id);
230 } else {
231 return Err(Error::http(
232 StatusCode::UNPROCESSABLE_ENTITY,
233 format!("user {impersonate} does not exist"),
234 ));
235 }
236 } else {
237 return Err(Error::http(
238 StatusCode::UNAUTHORIZED,
239 "you do not have permission to impersonate other users".to_string(),
240 ));
241 }
242 }
243
244 let access_token =
245 auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
246 let encrypted_access_token =
247 auth::encrypt_access_token(&access_token, params.public_key.clone())?;
248
249 Ok(Json(CreateAccessTokenResponse {
250 user_id: impersonated_user_id.unwrap_or(user_id),
251 encrypted_access_token,
252 }))
253}