1pub mod billing;
2pub mod contributors;
3pub mod events;
4pub mod extensions;
5pub mod ips_file;
6pub mod slack;
7
8use crate::api::events::SnowflakeRow;
9use crate::{
10 auth,
11 db::{User, UserId},
12 rpc, AppState, Error, Result,
13};
14use anyhow::anyhow;
15use axum::{
16 body::Body,
17 extract::{Path, Query},
18 headers::Header,
19 http::{self, HeaderName, Request, StatusCode},
20 middleware::{self, Next},
21 response::IntoResponse,
22 routing::{get, post},
23 Extension, Json, Router,
24};
25use axum_extra::response::ErasedJson;
26use serde::{Deserialize, Serialize};
27use std::sync::{Arc, OnceLock};
28use tower::ServiceBuilder;
29
30pub use extensions::fetch_extensions_from_blob_store_periodically;
31
32pub struct CloudflareIpCountryHeader(String);
33
34impl Header for CloudflareIpCountryHeader {
35 fn name() -> &'static HeaderName {
36 static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
37 CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
38 }
39
40 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
41 where
42 Self: Sized,
43 I: Iterator<Item = &'i axum::http::HeaderValue>,
44 {
45 let country_code = values
46 .next()
47 .ok_or_else(axum::headers::Error::invalid)?
48 .to_str()
49 .map_err(|_| axum::headers::Error::invalid())?;
50
51 Ok(Self(country_code.to_string()))
52 }
53
54 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
55 unimplemented!()
56 }
57}
58
59impl std::fmt::Display for CloudflareIpCountryHeader {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 write!(f, "{}", self.0)
62 }
63}
64
65pub struct SystemIdHeader(String);
66
67impl Header for SystemIdHeader {
68 fn name() -> &'static HeaderName {
69 static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
70 SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
71 }
72
73 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
74 where
75 Self: Sized,
76 I: Iterator<Item = &'i axum::http::HeaderValue>,
77 {
78 let system_id = values
79 .next()
80 .ok_or_else(axum::headers::Error::invalid)?
81 .to_str()
82 .map_err(|_| axum::headers::Error::invalid())?;
83
84 Ok(Self(system_id.to_string()))
85 }
86
87 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
88 unimplemented!()
89 }
90}
91
92impl std::fmt::Display for SystemIdHeader {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 write!(f, "{}", self.0)
95 }
96}
97
98pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
99 Router::new()
100 .route("/user", get(get_authenticated_user))
101 .route("/users/:id/access_tokens", post(create_access_token))
102 .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
103 .route("/snowflake/events", post(write_snowflake_event))
104 .merge(billing::router())
105 .merge(contributors::router())
106 .layer(
107 ServiceBuilder::new()
108 .layer(Extension(rpc_server))
109 .layer(middleware::from_fn(validate_api_token)),
110 )
111}
112
113pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
114 let token = req
115 .headers()
116 .get(http::header::AUTHORIZATION)
117 .and_then(|header| header.to_str().ok())
118 .ok_or_else(|| {
119 Error::http(
120 StatusCode::BAD_REQUEST,
121 "missing authorization header".to_string(),
122 )
123 })?
124 .strip_prefix("token ")
125 .ok_or_else(|| {
126 Error::http(
127 StatusCode::BAD_REQUEST,
128 "invalid authorization header".to_string(),
129 )
130 })?;
131
132 let state = req.extensions().get::<Arc<AppState>>().unwrap();
133
134 if token != state.config.api_token {
135 Err(Error::http(
136 StatusCode::UNAUTHORIZED,
137 "invalid authorization token".to_string(),
138 ))?
139 }
140
141 Ok::<_, Error>(next.run(req).await)
142}
143
/// Query parameters accepted by `GET /user`, describing the GitHub account of the user to look
/// up or create.
#[derive(Debug, Deserialize)]
struct AuthenticatedUserParams {
    github_user_id: i32,
    github_login: String,
    /// Optional; forwarded to the database as `Option<&str>`.
    github_email: Option<String>,
    /// Optional; forwarded to the database as `Option<&str>`.
    github_name: Option<String>,
    /// Creation time of the GitHub account, in UTC.
    github_user_created_at: chrono::DateTime<chrono::Utc>,
}
152
/// JSON body returned by `GET /user`.
#[derive(Debug, Serialize)]
struct AuthenticatedUserResponse {
    /// The user found or created for the requested GitHub account.
    user: User,
    /// Value returned by `get_user_metrics_id` for this user.
    metrics_id: String,
}
158
159async fn get_authenticated_user(
160 Query(params): Query<AuthenticatedUserParams>,
161 Extension(app): Extension<Arc<AppState>>,
162) -> Result<Json<AuthenticatedUserResponse>> {
163 let initial_channel_id = app.config.auto_join_channel_id;
164
165 let user = app
166 .db
167 .get_or_create_user_by_github_account(
168 ¶ms.github_login,
169 params.github_user_id,
170 params.github_email.as_deref(),
171 params.github_name.as_deref(),
172 params.github_user_created_at,
173 initial_channel_id,
174 )
175 .await?;
176 let metrics_id = app.db.get_user_metrics_id(user.id).await?;
177 Ok(Json(AuthenticatedUserResponse { user, metrics_id }))
178}
179
/// Deserializable parameters describing a user to create.
///
/// NOTE(review): nothing in this file references this type. It may be left over from a removed
/// endpoint; confirm before deleting.
#[derive(Deserialize, Debug)]
struct CreateUserParams {
    github_user_id: i32,
    github_login: String,
    email_address: String,
    email_confirmation_code: Option<String>,
    /// Defaults to `false` when absent.
    #[serde(default)]
    admin: bool,
    /// Defaults to `0` when absent.
    #[serde(default)]
    invite_count: i32,
}
191
192async fn get_rpc_server_snapshot(
193 Extension(rpc_server): Extension<Arc<rpc::Server>>,
194) -> Result<ErasedJson> {
195 Ok(ErasedJson::pretty(rpc_server.snapshot().await))
196}
197
/// Query parameters for `POST /users/:id/access_tokens`.
#[derive(Deserialize)]
struct CreateAccessTokenQueryParams {
    /// Key passed to `auth::encrypt_access_token` to encrypt the newly issued token.
    public_key: String,
    /// GitHub login of a user to impersonate. Only honored for admin users.
    impersonate: Option<String>,
}
203
/// JSON body returned by `POST /users/:id/access_tokens`.
#[derive(Serialize)]
struct CreateAccessTokenResponse {
    /// The impersonated user's ID when impersonating, otherwise the ID from the path.
    user_id: UserId,
    /// The access token, encrypted with the caller-supplied `public_key`.
    encrypted_access_token: String,
}
209
210async fn create_access_token(
211 Path(user_id): Path<UserId>,
212 Query(params): Query<CreateAccessTokenQueryParams>,
213 Extension(app): Extension<Arc<AppState>>,
214) -> Result<Json<CreateAccessTokenResponse>> {
215 let user = app
216 .db
217 .get_user_by_id(user_id)
218 .await?
219 .ok_or_else(|| anyhow!("user not found"))?;
220
221 let mut impersonated_user_id = None;
222 if let Some(impersonate) = params.impersonate {
223 if user.admin {
224 if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
225 impersonated_user_id = Some(impersonated_user.id);
226 } else {
227 return Err(Error::http(
228 StatusCode::UNPROCESSABLE_ENTITY,
229 format!("user {impersonate} does not exist"),
230 ));
231 }
232 } else {
233 return Err(Error::http(
234 StatusCode::UNAUTHORIZED,
235 "you do not have permission to impersonate other users".to_string(),
236 ));
237 }
238 }
239
240 let access_token =
241 auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
242 let encrypted_access_token =
243 auth::encrypt_access_token(&access_token, params.public_key.clone())?;
244
245 Ok(Json(CreateAccessTokenResponse {
246 user_id: impersonated_user_id.unwrap_or(user_id),
247 encrypted_access_token,
248 }))
249}
250
251/// An endpoint that writes a Snowflake event to our event stream.
252///
253/// This endpoint is exposed such that other internal services can write
254/// telemetry events without needing to talk to AWS Kinesis directly.
255async fn write_snowflake_event(
256 Extension(app): Extension<Arc<AppState>>,
257 Json(event): Json<SnowflakeRow>,
258) -> Result<()> {
259 let kinesis_client = app.kinesis_client.clone();
260 let kinesis_stream = app.config.kinesis_stream.clone();
261
262 event.write(&kinesis_client, &kinesis_stream).await?;
263
264 Ok(())
265}