api.rs

  1pub mod contributors;
  2pub mod events;
  3pub mod extensions;
  4pub mod ips_file;
  5pub mod slack;
  6
  7use crate::{AppState, Error, Result, auth, db::UserId, rpc};
  8use anyhow::Context as _;
  9use axum::{
 10    Extension, Json, Router,
 11    body::Body,
 12    extract::{Path, Query},
 13    headers::Header,
 14    http::{self, HeaderName, Request, StatusCode},
 15    middleware::{self, Next},
 16    response::IntoResponse,
 17    routing::{get, post},
 18};
 19use axum_extra::response::ErasedJson;
 20use serde::{Deserialize, Serialize};
 21use std::sync::{Arc, OnceLock};
 22use tower::ServiceBuilder;
 23
 24pub use extensions::fetch_extensions_from_blob_store_periodically;
 25
 26pub struct CloudflareIpCountryHeader(String);
 27
 28impl Header for CloudflareIpCountryHeader {
 29    fn name() -> &'static HeaderName {
 30        static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
 31        CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
 32    }
 33
 34    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 35    where
 36        Self: Sized,
 37        I: Iterator<Item = &'i axum::http::HeaderValue>,
 38    {
 39        let country_code = values
 40            .next()
 41            .ok_or_else(axum::headers::Error::invalid)?
 42            .to_str()
 43            .map_err(|_| axum::headers::Error::invalid())?;
 44
 45        Ok(Self(country_code.to_string()))
 46    }
 47
 48    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
 49        unimplemented!()
 50    }
 51}
 52
 53impl std::fmt::Display for CloudflareIpCountryHeader {
 54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 55        write!(f, "{}", self.0)
 56    }
 57}
 58
 59pub struct SystemIdHeader(String);
 60
 61impl Header for SystemIdHeader {
 62    fn name() -> &'static HeaderName {
 63        static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
 64        SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
 65    }
 66
 67    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 68    where
 69        Self: Sized,
 70        I: Iterator<Item = &'i axum::http::HeaderValue>,
 71    {
 72        let system_id = values
 73            .next()
 74            .ok_or_else(axum::headers::Error::invalid)?
 75            .to_str()
 76            .map_err(|_| axum::headers::Error::invalid())?;
 77
 78        Ok(Self(system_id.to_string()))
 79    }
 80
 81    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
 82        unimplemented!()
 83    }
 84}
 85
 86impl std::fmt::Display for SystemIdHeader {
 87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 88        write!(f, "{}", self.0)
 89    }
 90}
 91
 92pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
 93    Router::new()
 94        .route("/users/:id/access_tokens", post(create_access_token))
 95        .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
 96        .merge(contributors::router())
 97        .layer(
 98            ServiceBuilder::new()
 99                .layer(Extension(rpc_server))
100                .layer(middleware::from_fn(validate_api_token)),
101        )
102}
103
104pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
105    let token = req
106        .headers()
107        .get(http::header::AUTHORIZATION)
108        .and_then(|header| header.to_str().ok())
109        .ok_or_else(|| {
110            Error::http(
111                StatusCode::BAD_REQUEST,
112                "missing authorization header".to_string(),
113            )
114        })?
115        .strip_prefix("token ")
116        .ok_or_else(|| {
117            Error::http(
118                StatusCode::BAD_REQUEST,
119                "invalid authorization header".to_string(),
120            )
121        })?;
122
123    let state = req.extensions().get::<Arc<AppState>>().unwrap();
124
125    if token != state.config.api_token {
126        Err(Error::http(
127            StatusCode::UNAUTHORIZED,
128            "invalid authorization token".to_string(),
129        ))?
130    }
131
132    Ok::<_, Error>(next.run(req).await)
133}
134
135async fn get_rpc_server_snapshot(
136    Extension(rpc_server): Extension<Arc<rpc::Server>>,
137) -> Result<ErasedJson> {
138    Ok(ErasedJson::pretty(rpc_server.snapshot().await))
139}
140
141#[derive(Deserialize)]
142struct CreateAccessTokenQueryParams {
143    public_key: String,
144    impersonate: Option<String>,
145}
146
147#[derive(Serialize)]
148struct CreateAccessTokenResponse {
149    user_id: UserId,
150    encrypted_access_token: String,
151}
152
153async fn create_access_token(
154    Path(user_id): Path<UserId>,
155    Query(params): Query<CreateAccessTokenQueryParams>,
156    Extension(app): Extension<Arc<AppState>>,
157) -> Result<Json<CreateAccessTokenResponse>> {
158    let user = app
159        .db
160        .get_user_by_id(user_id)
161        .await?
162        .context("user not found")?;
163
164    let mut impersonated_user_id = None;
165    if let Some(impersonate) = params.impersonate {
166        if user.admin {
167            if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
168                impersonated_user_id = Some(impersonated_user.id);
169            } else {
170                return Err(Error::http(
171                    StatusCode::UNPROCESSABLE_ENTITY,
172                    format!("user {impersonate} does not exist"),
173                ));
174            }
175        } else {
176            return Err(Error::http(
177                StatusCode::UNAUTHORIZED,
178                "you do not have permission to impersonate other users".to_string(),
179            ));
180        }
181    }
182
183    let access_token =
184        auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
185    let encrypted_access_token =
186        auth::encrypt_access_token(&access_token, params.public_key.clone())?;
187
188    Ok(Json(CreateAccessTokenResponse {
189        user_id: impersonated_user_id.unwrap_or(user_id),
190        encrypted_access_token,
191    }))
192}