1pub mod contributors;
2pub mod events;
3pub mod extensions;
4pub mod ips_file;
5pub mod slack;
6
7use crate::{AppState, Error, Result, auth, db::UserId, rpc};
8use anyhow::Context as _;
9use axum::{
10 Extension, Json, Router,
11 body::Body,
12 extract::{Path, Query},
13 headers::Header,
14 http::{self, HeaderName, Request, StatusCode},
15 middleware::{self, Next},
16 response::IntoResponse,
17 routing::{get, post},
18};
19use axum_extra::response::ErasedJson;
20use serde::{Deserialize, Serialize};
21use std::sync::{Arc, OnceLock};
22use tower::ServiceBuilder;
23
24pub use extensions::fetch_extensions_from_blob_store_periodically;
25
26pub struct CloudflareIpCountryHeader(String);
27
28impl Header for CloudflareIpCountryHeader {
29 fn name() -> &'static HeaderName {
30 static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
31 CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
32 }
33
34 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
35 where
36 Self: Sized,
37 I: Iterator<Item = &'i axum::http::HeaderValue>,
38 {
39 let country_code = values
40 .next()
41 .ok_or_else(axum::headers::Error::invalid)?
42 .to_str()
43 .map_err(|_| axum::headers::Error::invalid())?;
44
45 Ok(Self(country_code.to_string()))
46 }
47
48 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
49 unimplemented!()
50 }
51}
52
53impl std::fmt::Display for CloudflareIpCountryHeader {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 write!(f, "{}", self.0)
56 }
57}
58
59pub struct SystemIdHeader(String);
60
61impl Header for SystemIdHeader {
62 fn name() -> &'static HeaderName {
63 static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
64 SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
65 }
66
67 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
68 where
69 Self: Sized,
70 I: Iterator<Item = &'i axum::http::HeaderValue>,
71 {
72 let system_id = values
73 .next()
74 .ok_or_else(axum::headers::Error::invalid)?
75 .to_str()
76 .map_err(|_| axum::headers::Error::invalid())?;
77
78 Ok(Self(system_id.to_string()))
79 }
80
81 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
82 unimplemented!()
83 }
84}
85
86impl std::fmt::Display for SystemIdHeader {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
92pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
93 Router::new()
94 .route("/users/:id/access_tokens", post(create_access_token))
95 .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
96 .merge(contributors::router())
97 .layer(
98 ServiceBuilder::new()
99 .layer(Extension(rpc_server))
100 .layer(middleware::from_fn(validate_api_token)),
101 )
102}
103
104pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
105 let token = req
106 .headers()
107 .get(http::header::AUTHORIZATION)
108 .and_then(|header| header.to_str().ok())
109 .ok_or_else(|| {
110 Error::http(
111 StatusCode::BAD_REQUEST,
112 "missing authorization header".to_string(),
113 )
114 })?
115 .strip_prefix("token ")
116 .ok_or_else(|| {
117 Error::http(
118 StatusCode::BAD_REQUEST,
119 "invalid authorization header".to_string(),
120 )
121 })?;
122
123 let state = req.extensions().get::<Arc<AppState>>().unwrap();
124
125 if token != state.config.api_token {
126 Err(Error::http(
127 StatusCode::UNAUTHORIZED,
128 "invalid authorization token".to_string(),
129 ))?
130 }
131
132 Ok::<_, Error>(next.run(req).await)
133}
134
135async fn get_rpc_server_snapshot(
136 Extension(rpc_server): Extension<Arc<rpc::Server>>,
137) -> Result<ErasedJson> {
138 Ok(ErasedJson::pretty(rpc_server.snapshot().await))
139}
140
141#[derive(Deserialize)]
142struct CreateAccessTokenQueryParams {
143 public_key: String,
144 impersonate: Option<String>,
145}
146
147#[derive(Serialize)]
148struct CreateAccessTokenResponse {
149 user_id: UserId,
150 encrypted_access_token: String,
151}
152
153async fn create_access_token(
154 Path(user_id): Path<UserId>,
155 Query(params): Query<CreateAccessTokenQueryParams>,
156 Extension(app): Extension<Arc<AppState>>,
157) -> Result<Json<CreateAccessTokenResponse>> {
158 let user = app
159 .db
160 .get_user_by_id(user_id)
161 .await?
162 .context("user not found")?;
163
164 let mut impersonated_user_id = None;
165 if let Some(impersonate) = params.impersonate {
166 if user.admin {
167 if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
168 impersonated_user_id = Some(impersonated_user.id);
169 } else {
170 return Err(Error::http(
171 StatusCode::UNPROCESSABLE_ENTITY,
172 format!("user {impersonate} does not exist"),
173 ));
174 }
175 } else {
176 return Err(Error::http(
177 StatusCode::UNAUTHORIZED,
178 "you do not have permission to impersonate other users".to_string(),
179 ));
180 }
181 }
182
183 let access_token =
184 auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
185 let encrypted_access_token =
186 auth::encrypt_access_token(&access_token, params.public_key.clone())?;
187
188 Ok(Json(CreateAccessTokenResponse {
189 user_id: impersonated_user_id.unwrap_or(user_id),
190 encrypted_access_token,
191 }))
192}