1pub mod contributors;
  2pub mod events;
  3pub mod extensions;
  4
  5use crate::{AppState, Error, Result, auth, db::UserId, rpc};
  6use anyhow::Context as _;
  7use axum::{
  8    Extension, Json, Router,
  9    body::Body,
 10    extract::{Path, Query},
 11    headers::Header,
 12    http::{self, HeaderName, Request, StatusCode},
 13    middleware::{self, Next},
 14    response::IntoResponse,
 15    routing::{get, post},
 16};
 17use axum_extra::response::ErasedJson;
 18use serde::{Deserialize, Serialize};
 19use std::sync::{Arc, OnceLock};
 20use tower::ServiceBuilder;
 21
 22pub use extensions::fetch_extensions_from_blob_store_periodically;
 23
 24pub struct CloudflareIpCountryHeader(String);
 25
 26impl Header for CloudflareIpCountryHeader {
 27    fn name() -> &'static HeaderName {
 28        static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
 29        CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
 30    }
 31
 32    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 33    where
 34        Self: Sized,
 35        I: Iterator<Item = &'i axum::http::HeaderValue>,
 36    {
 37        let country_code = values
 38            .next()
 39            .ok_or_else(axum::headers::Error::invalid)?
 40            .to_str()
 41            .map_err(|_| axum::headers::Error::invalid())?;
 42
 43        Ok(Self(country_code.to_string()))
 44    }
 45
 46    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
 47        unimplemented!()
 48    }
 49}
 50
 51impl std::fmt::Display for CloudflareIpCountryHeader {
 52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 53        write!(f, "{}", self.0)
 54    }
 55}
 56
 57pub struct SystemIdHeader(String);
 58
 59impl Header for SystemIdHeader {
 60    fn name() -> &'static HeaderName {
 61        static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
 62        SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
 63    }
 64
 65    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 66    where
 67        Self: Sized,
 68        I: Iterator<Item = &'i axum::http::HeaderValue>,
 69    {
 70        let system_id = values
 71            .next()
 72            .ok_or_else(axum::headers::Error::invalid)?
 73            .to_str()
 74            .map_err(|_| axum::headers::Error::invalid())?;
 75
 76        Ok(Self(system_id.to_string()))
 77    }
 78
 79    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
 80        unimplemented!()
 81    }
 82}
 83
 84impl std::fmt::Display for SystemIdHeader {
 85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 86        write!(f, "{}", self.0)
 87    }
 88}
 89
 90pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
 91    Router::new()
 92        .route("/users/:id/access_tokens", post(create_access_token))
 93        .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
 94        .merge(contributors::router())
 95        .layer(
 96            ServiceBuilder::new()
 97                .layer(Extension(rpc_server))
 98                .layer(middleware::from_fn(validate_api_token)),
 99        )
100}
101
102pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
103    let token = req
104        .headers()
105        .get(http::header::AUTHORIZATION)
106        .and_then(|header| header.to_str().ok())
107        .ok_or_else(|| {
108            Error::http(
109                StatusCode::BAD_REQUEST,
110                "missing authorization header".to_string(),
111            )
112        })?
113        .strip_prefix("token ")
114        .ok_or_else(|| {
115            Error::http(
116                StatusCode::BAD_REQUEST,
117                "invalid authorization header".to_string(),
118            )
119        })?;
120
121    let state = req.extensions().get::<Arc<AppState>>().unwrap();
122
123    if token != state.config.api_token {
124        Err(Error::http(
125            StatusCode::UNAUTHORIZED,
126            "invalid authorization token".to_string(),
127        ))?
128    }
129
130    Ok::<_, Error>(next.run(req).await)
131}
132
133async fn get_rpc_server_snapshot(
134    Extension(rpc_server): Extension<Arc<rpc::Server>>,
135) -> Result<ErasedJson> {
136    Ok(ErasedJson::pretty(rpc_server.snapshot().await))
137}
138
139#[derive(Deserialize)]
140struct CreateAccessTokenQueryParams {
141    public_key: String,
142    impersonate: Option<String>,
143}
144
145#[derive(Serialize)]
146struct CreateAccessTokenResponse {
147    user_id: UserId,
148    encrypted_access_token: String,
149}
150
151async fn create_access_token(
152    Path(user_id): Path<UserId>,
153    Query(params): Query<CreateAccessTokenQueryParams>,
154    Extension(app): Extension<Arc<AppState>>,
155) -> Result<Json<CreateAccessTokenResponse>> {
156    let user = app
157        .db
158        .get_user_by_id(user_id)
159        .await?
160        .context("user not found")?;
161
162    let mut impersonated_user_id = None;
163    if let Some(impersonate) = params.impersonate {
164        if user.admin {
165            if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
166                impersonated_user_id = Some(impersonated_user.id);
167            } else {
168                return Err(Error::http(
169                    StatusCode::UNPROCESSABLE_ENTITY,
170                    format!("user {impersonate} does not exist"),
171                ));
172            }
173        } else {
174            return Err(Error::http(
175                StatusCode::UNAUTHORIZED,
176                "you do not have permission to impersonate other users".to_string(),
177            ));
178        }
179    }
180
181    let access_token =
182        auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
183    let encrypted_access_token =
184        auth::encrypt_access_token(&access_token, params.public_key.clone())?;
185
186    Ok(Json(CreateAccessTokenResponse {
187        user_id: impersonated_user_id.unwrap_or(user_id),
188        encrypted_access_token,
189    }))
190}