api.rs

  1pub mod billing;
  2pub mod contributors;
  3pub mod events;
  4pub mod extensions;
  5pub mod ips_file;
  6pub mod slack;
  7
  8use crate::{
  9    AppState, Error, Result, auth,
 10    db::{User, UserId},
 11    rpc,
 12};
 13use anyhow::anyhow;
 14use axum::{
 15    Extension, Json, Router,
 16    body::Body,
 17    extract::{Path, Query},
 18    headers::Header,
 19    http::{self, HeaderName, Request, StatusCode},
 20    middleware::{self, Next},
 21    response::IntoResponse,
 22    routing::{get, post},
 23};
 24use axum_extra::response::ErasedJson;
 25use serde::{Deserialize, Serialize};
 26use std::sync::{Arc, OnceLock};
 27use tower::ServiceBuilder;
 28
 29pub use extensions::fetch_extensions_from_blob_store_periodically;
 30
 31pub struct CloudflareIpCountryHeader(String);
 32
 33impl Header for CloudflareIpCountryHeader {
 34    fn name() -> &'static HeaderName {
 35        static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
 36        CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
 37    }
 38
 39    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 40    where
 41        Self: Sized,
 42        I: Iterator<Item = &'i axum::http::HeaderValue>,
 43    {
 44        let country_code = values
 45            .next()
 46            .ok_or_else(axum::headers::Error::invalid)?
 47            .to_str()
 48            .map_err(|_| axum::headers::Error::invalid())?;
 49
 50        Ok(Self(country_code.to_string()))
 51    }
 52
 53    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
 54        unimplemented!()
 55    }
 56}
 57
 58impl std::fmt::Display for CloudflareIpCountryHeader {
 59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 60        write!(f, "{}", self.0)
 61    }
 62}
 63
 64pub struct SystemIdHeader(String);
 65
 66impl Header for SystemIdHeader {
 67    fn name() -> &'static HeaderName {
 68        static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
 69        SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
 70    }
 71
 72    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 73    where
 74        Self: Sized,
 75        I: Iterator<Item = &'i axum::http::HeaderValue>,
 76    {
 77        let system_id = values
 78            .next()
 79            .ok_or_else(axum::headers::Error::invalid)?
 80            .to_str()
 81            .map_err(|_| axum::headers::Error::invalid())?;
 82
 83        Ok(Self(system_id.to_string()))
 84    }
 85
 86    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
 87        unimplemented!()
 88    }
 89}
 90
 91impl std::fmt::Display for SystemIdHeader {
 92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 93        write!(f, "{}", self.0)
 94    }
 95}
 96
 97pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
 98    Router::new()
 99        .route("/user", get(get_authenticated_user))
100        .route("/users/:id/access_tokens", post(create_access_token))
101        .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
102        .merge(billing::router())
103        .merge(contributors::router())
104        .layer(
105            ServiceBuilder::new()
106                .layer(Extension(rpc_server))
107                .layer(middleware::from_fn(validate_api_token)),
108        )
109}
110
111pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
112    let token = req
113        .headers()
114        .get(http::header::AUTHORIZATION)
115        .and_then(|header| header.to_str().ok())
116        .ok_or_else(|| {
117            Error::http(
118                StatusCode::BAD_REQUEST,
119                "missing authorization header".to_string(),
120            )
121        })?
122        .strip_prefix("token ")
123        .ok_or_else(|| {
124            Error::http(
125                StatusCode::BAD_REQUEST,
126                "invalid authorization header".to_string(),
127            )
128        })?;
129
130    let state = req.extensions().get::<Arc<AppState>>().unwrap();
131
132    if token != state.config.api_token {
133        Err(Error::http(
134            StatusCode::UNAUTHORIZED,
135            "invalid authorization token".to_string(),
136        ))?
137    }
138
139    Ok::<_, Error>(next.run(req).await)
140}
141
142#[derive(Debug, Deserialize)]
143struct AuthenticatedUserParams {
144    github_user_id: i32,
145    github_login: String,
146    github_email: Option<String>,
147    github_name: Option<String>,
148    github_user_created_at: chrono::DateTime<chrono::Utc>,
149}
150
151#[derive(Debug, Serialize)]
152struct AuthenticatedUserResponse {
153    user: User,
154    metrics_id: String,
155    feature_flags: Vec<String>,
156}
157
158async fn get_authenticated_user(
159    Query(params): Query<AuthenticatedUserParams>,
160    Extension(app): Extension<Arc<AppState>>,
161) -> Result<Json<AuthenticatedUserResponse>> {
162    let initial_channel_id = app.config.auto_join_channel_id;
163
164    let user = app
165        .db
166        .get_or_create_user_by_github_account(
167            &params.github_login,
168            params.github_user_id,
169            params.github_email.as_deref(),
170            params.github_name.as_deref(),
171            params.github_user_created_at,
172            initial_channel_id,
173        )
174        .await?;
175    let metrics_id = app.db.get_user_metrics_id(user.id).await?;
176    let feature_flags = app.db.get_user_flags(user.id).await?;
177    Ok(Json(AuthenticatedUserResponse {
178        user,
179        metrics_id,
180        feature_flags,
181    }))
182}
183
184#[derive(Deserialize, Debug)]
185struct CreateUserParams {
186    github_user_id: i32,
187    github_login: String,
188    email_address: String,
189    email_confirmation_code: Option<String>,
190    #[serde(default)]
191    admin: bool,
192    #[serde(default)]
193    invite_count: i32,
194}
195
196async fn get_rpc_server_snapshot(
197    Extension(rpc_server): Extension<Arc<rpc::Server>>,
198) -> Result<ErasedJson> {
199    Ok(ErasedJson::pretty(rpc_server.snapshot().await))
200}
201
202#[derive(Deserialize)]
203struct CreateAccessTokenQueryParams {
204    public_key: String,
205    impersonate: Option<String>,
206}
207
208#[derive(Serialize)]
209struct CreateAccessTokenResponse {
210    user_id: UserId,
211    encrypted_access_token: String,
212}
213
214async fn create_access_token(
215    Path(user_id): Path<UserId>,
216    Query(params): Query<CreateAccessTokenQueryParams>,
217    Extension(app): Extension<Arc<AppState>>,
218) -> Result<Json<CreateAccessTokenResponse>> {
219    let user = app
220        .db
221        .get_user_by_id(user_id)
222        .await?
223        .ok_or_else(|| anyhow!("user not found"))?;
224
225    let mut impersonated_user_id = None;
226    if let Some(impersonate) = params.impersonate {
227        if user.admin {
228            if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
229                impersonated_user_id = Some(impersonated_user.id);
230            } else {
231                return Err(Error::http(
232                    StatusCode::UNPROCESSABLE_ENTITY,
233                    format!("user {impersonate} does not exist"),
234                ));
235            }
236        } else {
237            return Err(Error::http(
238                StatusCode::UNAUTHORIZED,
239                "you do not have permission to impersonate other users".to_string(),
240            ));
241        }
242    }
243
244    let access_token =
245        auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
246    let encrypted_access_token =
247        auth::encrypt_access_token(&access_token, params.public_key.clone())?;
248
249    Ok(Json(CreateAccessTokenResponse {
250        user_id: impersonated_user_id.unwrap_or(user_id),
251        encrypted_access_token,
252    }))
253}