api.rs

  1pub mod events;
  2pub mod extensions;
  3
  4use crate::{AppState, Error, Result, auth, db::UserId, rpc};
  5use anyhow::Context as _;
  6use axum::{
  7    Extension, Json, Router,
  8    body::Body,
  9    extract::{Path, Query},
 10    headers::Header,
 11    http::{self, HeaderName, Request, StatusCode},
 12    middleware::{self, Next},
 13    response::IntoResponse,
 14    routing::{get, post},
 15};
 16use axum_extra::response::ErasedJson;
 17use serde::{Deserialize, Serialize};
 18use std::sync::{Arc, OnceLock};
 19use tower::ServiceBuilder;
 20
 21pub use extensions::fetch_extensions_from_blob_store_periodically;
 22
 23pub struct CloudflareIpCountryHeader(String);
 24
 25impl Header for CloudflareIpCountryHeader {
 26    fn name() -> &'static HeaderName {
 27        static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
 28        CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
 29    }
 30
 31    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 32    where
 33        Self: Sized,
 34        I: Iterator<Item = &'i axum::http::HeaderValue>,
 35    {
 36        let country_code = values
 37            .next()
 38            .ok_or_else(axum::headers::Error::invalid)?
 39            .to_str()
 40            .map_err(|_| axum::headers::Error::invalid())?;
 41
 42        Ok(Self(country_code.to_string()))
 43    }
 44
 45    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
 46        unimplemented!()
 47    }
 48}
 49
 50impl std::fmt::Display for CloudflareIpCountryHeader {
 51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 52        write!(f, "{}", self.0)
 53    }
 54}
 55
 56pub struct SystemIdHeader(String);
 57
 58impl Header for SystemIdHeader {
 59    fn name() -> &'static HeaderName {
 60        static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
 61        SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
 62    }
 63
 64    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 65    where
 66        Self: Sized,
 67        I: Iterator<Item = &'i axum::http::HeaderValue>,
 68    {
 69        let system_id = values
 70            .next()
 71            .ok_or_else(axum::headers::Error::invalid)?
 72            .to_str()
 73            .map_err(|_| axum::headers::Error::invalid())?;
 74
 75        Ok(Self(system_id.to_string()))
 76    }
 77
 78    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
 79        unimplemented!()
 80    }
 81}
 82
 83impl std::fmt::Display for SystemIdHeader {
 84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 85        write!(f, "{}", self.0)
 86    }
 87}
 88
 89pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
 90    Router::new()
 91        .route("/users/:id/access_tokens", post(create_access_token))
 92        .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
 93        .layer(
 94            ServiceBuilder::new()
 95                .layer(Extension(rpc_server))
 96                .layer(middleware::from_fn(validate_api_token)),
 97        )
 98}
 99
100pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
101    let token = req
102        .headers()
103        .get(http::header::AUTHORIZATION)
104        .and_then(|header| header.to_str().ok())
105        .ok_or_else(|| {
106            Error::http(
107                StatusCode::BAD_REQUEST,
108                "missing authorization header".to_string(),
109            )
110        })?
111        .strip_prefix("token ")
112        .ok_or_else(|| {
113            Error::http(
114                StatusCode::BAD_REQUEST,
115                "invalid authorization header".to_string(),
116            )
117        })?;
118
119    let state = req.extensions().get::<Arc<AppState>>().unwrap();
120
121    if token != state.config.api_token {
122        Err(Error::http(
123            StatusCode::UNAUTHORIZED,
124            "invalid authorization token".to_string(),
125        ))?
126    }
127
128    Ok::<_, Error>(next.run(req).await)
129}
130
131async fn get_rpc_server_snapshot(
132    Extension(rpc_server): Extension<Arc<rpc::Server>>,
133) -> Result<ErasedJson> {
134    Ok(ErasedJson::pretty(rpc_server.snapshot().await))
135}
136
137#[derive(Deserialize)]
138struct CreateAccessTokenQueryParams {
139    public_key: String,
140    impersonate: Option<String>,
141}
142
143#[derive(Serialize)]
144struct CreateAccessTokenResponse {
145    user_id: UserId,
146    encrypted_access_token: String,
147}
148
149async fn create_access_token(
150    Path(user_id): Path<UserId>,
151    Query(params): Query<CreateAccessTokenQueryParams>,
152    Extension(app): Extension<Arc<AppState>>,
153) -> Result<Json<CreateAccessTokenResponse>> {
154    let user = app
155        .db
156        .get_user_by_id(user_id)
157        .await?
158        .context("user not found")?;
159
160    let mut impersonated_user_id = None;
161    if let Some(impersonate) = params.impersonate {
162        if user.admin {
163            if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
164                impersonated_user_id = Some(impersonated_user.id);
165            } else {
166                return Err(Error::http(
167                    StatusCode::UNPROCESSABLE_ENTITY,
168                    format!("user {impersonate} does not exist"),
169                ));
170            }
171        } else {
172            return Err(Error::http(
173                StatusCode::UNAUTHORIZED,
174                "you do not have permission to impersonate other users".to_string(),
175            ));
176        }
177    }
178
179    let access_token =
180        auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
181    let encrypted_access_token =
182        auth::encrypt_access_token(&access_token, params.public_key.clone())?;
183
184    Ok(Json(CreateAccessTokenResponse {
185        user_id: impersonated_user_id.unwrap_or(user_id),
186        encrypted_access_token,
187    }))
188}