api.rs

  1pub mod billing;
  2pub mod contributors;
  3pub mod events;
  4pub mod extensions;
  5pub mod ips_file;
  6pub mod slack;
  7
  8use crate::{
  9    auth,
 10    db::{User, UserId},
 11    rpc, AppState, Error, Result,
 12};
 13use anyhow::anyhow;
 14use axum::{
 15    body::Body,
 16    extract::{Path, Query},
 17    headers::Header,
 18    http::{self, HeaderName, Request, StatusCode},
 19    middleware::{self, Next},
 20    response::IntoResponse,
 21    routing::{get, post},
 22    Extension, Json, Router,
 23};
 24use axum_extra::response::ErasedJson;
 25use serde::{Deserialize, Serialize};
 26use std::sync::{Arc, OnceLock};
 27use tower::ServiceBuilder;
 28
 29pub use extensions::fetch_extensions_from_blob_store_periodically;
 30
 31pub struct CloudflareIpCountryHeader(String);
 32
 33impl Header for CloudflareIpCountryHeader {
 34    fn name() -> &'static HeaderName {
 35        static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
 36        CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
 37    }
 38
 39    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 40    where
 41        Self: Sized,
 42        I: Iterator<Item = &'i axum::http::HeaderValue>,
 43    {
 44        let country_code = values
 45            .next()
 46            .ok_or_else(axum::headers::Error::invalid)?
 47            .to_str()
 48            .map_err(|_| axum::headers::Error::invalid())?;
 49
 50        Ok(Self(country_code.to_string()))
 51    }
 52
 53    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
 54        unimplemented!()
 55    }
 56}
 57
 58impl std::fmt::Display for CloudflareIpCountryHeader {
 59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 60        write!(f, "{}", self.0)
 61    }
 62}
 63
 64pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Router<(), Body> {
 65    Router::new()
 66        .route("/user", get(get_authenticated_user))
 67        .route("/users/:id/access_tokens", post(create_access_token))
 68        .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
 69        .merge(billing::router())
 70        .merge(contributors::router())
 71        .layer(
 72            ServiceBuilder::new()
 73                .layer(Extension(state))
 74                .layer(Extension(rpc_server))
 75                .layer(middleware::from_fn(validate_api_token)),
 76        )
 77}
 78
 79pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
 80    let token = req
 81        .headers()
 82        .get(http::header::AUTHORIZATION)
 83        .and_then(|header| header.to_str().ok())
 84        .ok_or_else(|| {
 85            Error::Http(
 86                StatusCode::BAD_REQUEST,
 87                "missing authorization header".to_string(),
 88            )
 89        })?
 90        .strip_prefix("token ")
 91        .ok_or_else(|| {
 92            Error::Http(
 93                StatusCode::BAD_REQUEST,
 94                "invalid authorization header".to_string(),
 95            )
 96        })?;
 97
 98    let state = req.extensions().get::<Arc<AppState>>().unwrap();
 99
100    if token != state.config.api_token {
101        Err(Error::Http(
102            StatusCode::UNAUTHORIZED,
103            "invalid authorization token".to_string(),
104        ))?
105    }
106
107    Ok::<_, Error>(next.run(req).await)
108}
109
110#[derive(Debug, Deserialize)]
111struct AuthenticatedUserParams {
112    github_user_id: Option<i32>,
113    github_login: String,
114    github_email: Option<String>,
115}
116
117#[derive(Debug, Serialize)]
118struct AuthenticatedUserResponse {
119    user: User,
120    metrics_id: String,
121}
122
123async fn get_authenticated_user(
124    Query(params): Query<AuthenticatedUserParams>,
125    Extension(app): Extension<Arc<AppState>>,
126) -> Result<Json<AuthenticatedUserResponse>> {
127    let initial_channel_id = app.config.auto_join_channel_id;
128
129    let user = app
130        .db
131        .get_or_create_user_by_github_account(
132            &params.github_login,
133            params.github_user_id,
134            params.github_email.as_deref(),
135            initial_channel_id,
136        )
137        .await?;
138    let metrics_id = app.db.get_user_metrics_id(user.id).await?;
139    return Ok(Json(AuthenticatedUserResponse { user, metrics_id }));
140}
141
142#[derive(Deserialize, Debug)]
143struct CreateUserParams {
144    github_user_id: i32,
145    github_login: String,
146    email_address: String,
147    email_confirmation_code: Option<String>,
148    #[serde(default)]
149    admin: bool,
150    #[serde(default)]
151    invite_count: i32,
152}
153
154async fn get_rpc_server_snapshot(
155    Extension(rpc_server): Extension<Option<Arc<rpc::Server>>>,
156) -> Result<ErasedJson> {
157    let Some(rpc_server) = rpc_server else {
158        return Err(Error::Internal(anyhow!("rpc server is not available")));
159    };
160
161    Ok(ErasedJson::pretty(rpc_server.snapshot().await))
162}
163
164#[derive(Deserialize)]
165struct CreateAccessTokenQueryParams {
166    public_key: String,
167    impersonate: Option<String>,
168}
169
170#[derive(Serialize)]
171struct CreateAccessTokenResponse {
172    user_id: UserId,
173    encrypted_access_token: String,
174}
175
176async fn create_access_token(
177    Path(user_id): Path<UserId>,
178    Query(params): Query<CreateAccessTokenQueryParams>,
179    Extension(app): Extension<Arc<AppState>>,
180) -> Result<Json<CreateAccessTokenResponse>> {
181    let user = app
182        .db
183        .get_user_by_id(user_id)
184        .await?
185        .ok_or_else(|| anyhow!("user not found"))?;
186
187    let mut impersonated_user_id = None;
188    if let Some(impersonate) = params.impersonate {
189        if user.admin {
190            if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
191                impersonated_user_id = Some(impersonated_user.id);
192            } else {
193                return Err(Error::Http(
194                    StatusCode::UNPROCESSABLE_ENTITY,
195                    format!("user {impersonate} does not exist"),
196                ));
197            }
198        } else {
199            return Err(Error::Http(
200                StatusCode::UNAUTHORIZED,
201                "you do not have permission to impersonate other users".to_string(),
202            ));
203        }
204    }
205
206    let access_token =
207        auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
208    let encrypted_access_token =
209        auth::encrypt_access_token(&access_token, params.public_key.clone())?;
210
211    Ok(Json(CreateAccessTokenResponse {
212        user_id: impersonated_user_id.unwrap_or(user_id),
213        encrypted_access_token,
214    }))
215}