1pub mod contributors;
2pub mod events;
3pub mod extensions;
4
5use crate::{AppState, Error, Result, auth, db::UserId, rpc};
6use anyhow::Context as _;
7use axum::{
8 Extension, Json, Router,
9 body::Body,
10 extract::{Path, Query},
11 headers::Header,
12 http::{self, HeaderName, Request, StatusCode},
13 middleware::{self, Next},
14 response::IntoResponse,
15 routing::{get, post},
16};
17use axum_extra::response::ErasedJson;
18use serde::{Deserialize, Serialize};
19use std::sync::{Arc, OnceLock};
20use tower::ServiceBuilder;
21
22pub use extensions::fetch_extensions_from_blob_store_periodically;
23
24pub struct CloudflareIpCountryHeader(String);
25
26impl Header for CloudflareIpCountryHeader {
27 fn name() -> &'static HeaderName {
28 static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
29 CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
30 }
31
32 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
33 where
34 Self: Sized,
35 I: Iterator<Item = &'i axum::http::HeaderValue>,
36 {
37 let country_code = values
38 .next()
39 .ok_or_else(axum::headers::Error::invalid)?
40 .to_str()
41 .map_err(|_| axum::headers::Error::invalid())?;
42
43 Ok(Self(country_code.to_string()))
44 }
45
46 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
47 unimplemented!()
48 }
49}
50
51impl std::fmt::Display for CloudflareIpCountryHeader {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 write!(f, "{}", self.0)
54 }
55}
56
57pub struct SystemIdHeader(String);
58
59impl Header for SystemIdHeader {
60 fn name() -> &'static HeaderName {
61 static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
62 SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
63 }
64
65 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
66 where
67 Self: Sized,
68 I: Iterator<Item = &'i axum::http::HeaderValue>,
69 {
70 let system_id = values
71 .next()
72 .ok_or_else(axum::headers::Error::invalid)?
73 .to_str()
74 .map_err(|_| axum::headers::Error::invalid())?;
75
76 Ok(Self(system_id.to_string()))
77 }
78
79 fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
80 unimplemented!()
81 }
82}
83
84impl std::fmt::Display for SystemIdHeader {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 write!(f, "{}", self.0)
87 }
88}
89
90pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
91 Router::new()
92 .route("/users/:id/access_tokens", post(create_access_token))
93 .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
94 .merge(contributors::router())
95 .layer(
96 ServiceBuilder::new()
97 .layer(Extension(rpc_server))
98 .layer(middleware::from_fn(validate_api_token)),
99 )
100}
101
102pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
103 let token = req
104 .headers()
105 .get(http::header::AUTHORIZATION)
106 .and_then(|header| header.to_str().ok())
107 .ok_or_else(|| {
108 Error::http(
109 StatusCode::BAD_REQUEST,
110 "missing authorization header".to_string(),
111 )
112 })?
113 .strip_prefix("token ")
114 .ok_or_else(|| {
115 Error::http(
116 StatusCode::BAD_REQUEST,
117 "invalid authorization header".to_string(),
118 )
119 })?;
120
121 let state = req.extensions().get::<Arc<AppState>>().unwrap();
122
123 if token != state.config.api_token {
124 Err(Error::http(
125 StatusCode::UNAUTHORIZED,
126 "invalid authorization token".to_string(),
127 ))?
128 }
129
130 Ok::<_, Error>(next.run(req).await)
131}
132
133async fn get_rpc_server_snapshot(
134 Extension(rpc_server): Extension<Arc<rpc::Server>>,
135) -> Result<ErasedJson> {
136 Ok(ErasedJson::pretty(rpc_server.snapshot().await))
137}
138
139#[derive(Deserialize)]
140struct CreateAccessTokenQueryParams {
141 public_key: String,
142 impersonate: Option<String>,
143}
144
145#[derive(Serialize)]
146struct CreateAccessTokenResponse {
147 user_id: UserId,
148 encrypted_access_token: String,
149}
150
151async fn create_access_token(
152 Path(user_id): Path<UserId>,
153 Query(params): Query<CreateAccessTokenQueryParams>,
154 Extension(app): Extension<Arc<AppState>>,
155) -> Result<Json<CreateAccessTokenResponse>> {
156 let user = app
157 .db
158 .get_user_by_id(user_id)
159 .await?
160 .context("user not found")?;
161
162 let mut impersonated_user_id = None;
163 if let Some(impersonate) = params.impersonate {
164 if user.admin {
165 if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
166 impersonated_user_id = Some(impersonated_user.id);
167 } else {
168 return Err(Error::http(
169 StatusCode::UNPROCESSABLE_ENTITY,
170 format!("user {impersonate} does not exist"),
171 ));
172 }
173 } else {
174 return Err(Error::http(
175 StatusCode::UNAUTHORIZED,
176 "you do not have permission to impersonate other users".to_string(),
177 ));
178 }
179 }
180
181 let access_token =
182 auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
183 let encrypted_access_token =
184 auth::encrypt_access_token(&access_token, params.public_key.clone())?;
185
186 Ok(Json(CreateAccessTokenResponse {
187 user_id: impersonated_user_id.unwrap_or(user_id),
188 encrypted_access_token,
189 }))
190}