1pub mod events;
2pub mod extensions;
3
4use crate::{AppState, Error, Result, auth, db::UserId, rpc};
5use anyhow::Context as _;
6use axum::{
7 Extension, Json, Router,
8 body::Body,
9 extract::{Path, Query},
10 headers::Header,
11 http::{self, HeaderName, Request, StatusCode},
12 middleware::{self, Next},
13 response::IntoResponse,
14 routing::{get, post},
15};
16use axum_extra::response::ErasedJson;
17use serde::{Deserialize, Serialize};
18use std::sync::{Arc, OnceLock};
19use tower::ServiceBuilder;
20
21pub use extensions::fetch_extensions_from_blob_store_periodically;
22
23pub struct CloudflareIpCountryHeader(String);
24
25impl Header for CloudflareIpCountryHeader {
26 fn name() -> &'static HeaderName {
27 static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
28 CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
29 }
30
31 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
32 where
33 Self: Sized,
34 I: Iterator<Item = &'i axum::http::HeaderValue>,
35 {
36 let country_code = values
37 .next()
38 .ok_or_else(axum::headers::Error::invalid)?
39 .to_str()
40 .map_err(|_| axum::headers::Error::invalid())?;
41
42 Ok(Self(country_code.to_string()))
43 }
44
45 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
46 unimplemented!()
47 }
48}
49
50impl std::fmt::Display for CloudflareIpCountryHeader {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 write!(f, "{}", self.0)
53 }
54}
55
56pub struct SystemIdHeader(String);
57
58impl Header for SystemIdHeader {
59 fn name() -> &'static HeaderName {
60 static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
61 SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
62 }
63
64 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
65 where
66 Self: Sized,
67 I: Iterator<Item = &'i axum::http::HeaderValue>,
68 {
69 let system_id = values
70 .next()
71 .ok_or_else(axum::headers::Error::invalid)?
72 .to_str()
73 .map_err(|_| axum::headers::Error::invalid())?;
74
75 Ok(Self(system_id.to_string()))
76 }
77
78 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
79 unimplemented!()
80 }
81}
82
83impl std::fmt::Display for SystemIdHeader {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 write!(f, "{}", self.0)
86 }
87}
88
89pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
90 Router::new()
91 .route("/users/:id/access_tokens", post(create_access_token))
92 .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
93 .layer(
94 ServiceBuilder::new()
95 .layer(Extension(rpc_server))
96 .layer(middleware::from_fn(validate_api_token)),
97 )
98}
99
100pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
101 let token = req
102 .headers()
103 .get(http::header::AUTHORIZATION)
104 .and_then(|header| header.to_str().ok())
105 .ok_or_else(|| {
106 Error::http(
107 StatusCode::BAD_REQUEST,
108 "missing authorization header".to_string(),
109 )
110 })?
111 .strip_prefix("token ")
112 .ok_or_else(|| {
113 Error::http(
114 StatusCode::BAD_REQUEST,
115 "invalid authorization header".to_string(),
116 )
117 })?;
118
119 let state = req.extensions().get::<Arc<AppState>>().unwrap();
120
121 if token != state.config.api_token {
122 Err(Error::http(
123 StatusCode::UNAUTHORIZED,
124 "invalid authorization token".to_string(),
125 ))?
126 }
127
128 Ok::<_, Error>(next.run(req).await)
129}
130
131async fn get_rpc_server_snapshot(
132 Extension(rpc_server): Extension<Arc<rpc::Server>>,
133) -> Result<ErasedJson> {
134 Ok(ErasedJson::pretty(rpc_server.snapshot().await))
135}
136
137#[derive(Deserialize)]
138struct CreateAccessTokenQueryParams {
139 public_key: String,
140 impersonate: Option<String>,
141}
142
143#[derive(Serialize)]
144struct CreateAccessTokenResponse {
145 user_id: UserId,
146 encrypted_access_token: String,
147}
148
149async fn create_access_token(
150 Path(user_id): Path<UserId>,
151 Query(params): Query<CreateAccessTokenQueryParams>,
152 Extension(app): Extension<Arc<AppState>>,
153) -> Result<Json<CreateAccessTokenResponse>> {
154 let user = app
155 .db
156 .get_user_by_id(user_id)
157 .await?
158 .context("user not found")?;
159
160 let mut impersonated_user_id = None;
161 if let Some(impersonate) = params.impersonate {
162 if user.admin {
163 if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
164 impersonated_user_id = Some(impersonated_user.id);
165 } else {
166 return Err(Error::http(
167 StatusCode::UNPROCESSABLE_ENTITY,
168 format!("user {impersonate} does not exist"),
169 ));
170 }
171 } else {
172 return Err(Error::http(
173 StatusCode::UNAUTHORIZED,
174 "you do not have permission to impersonate other users".to_string(),
175 ));
176 }
177 }
178
179 let access_token =
180 auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
181 let encrypted_access_token =
182 auth::encrypt_access_token(&access_token, params.public_key.clone())?;
183
184 Ok(Json(CreateAccessTokenResponse {
185 user_id: impersonated_user_id.unwrap_or(user_id),
186 encrypted_access_token,
187 }))
188}