api.rs

  1pub mod billing;
  2pub mod contributors;
  3pub mod events;
  4pub mod extensions;
  5pub mod ips_file;
  6pub mod slack;
  7
  8use crate::api::events::SnowflakeRow;
  9use crate::{
 10    auth,
 11    db::{User, UserId},
 12    rpc, AppState, Error, Result,
 13};
 14use anyhow::anyhow;
 15use axum::{
 16    body::Body,
 17    extract::{Path, Query},
 18    headers::Header,
 19    http::{self, HeaderName, Request, StatusCode},
 20    middleware::{self, Next},
 21    response::IntoResponse,
 22    routing::{get, post},
 23    Extension, Json, Router,
 24};
 25use axum_extra::response::ErasedJson;
 26use serde::{Deserialize, Serialize};
 27use std::sync::{Arc, OnceLock};
 28use tower::ServiceBuilder;
 29
 30pub use extensions::fetch_extensions_from_blob_store_periodically;
 31
 32pub struct CloudflareIpCountryHeader(String);
 33
 34impl Header for CloudflareIpCountryHeader {
 35    fn name() -> &'static HeaderName {
 36        static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
 37        CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
 38    }
 39
 40    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 41    where
 42        Self: Sized,
 43        I: Iterator<Item = &'i axum::http::HeaderValue>,
 44    {
 45        let country_code = values
 46            .next()
 47            .ok_or_else(axum::headers::Error::invalid)?
 48            .to_str()
 49            .map_err(|_| axum::headers::Error::invalid())?;
 50
 51        Ok(Self(country_code.to_string()))
 52    }
 53
 54    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
 55        unimplemented!()
 56    }
 57}
 58
 59impl std::fmt::Display for CloudflareIpCountryHeader {
 60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 61        write!(f, "{}", self.0)
 62    }
 63}
 64
 65pub struct SystemIdHeader(String);
 66
 67impl Header for SystemIdHeader {
 68    fn name() -> &'static HeaderName {
 69        static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
 70        SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
 71    }
 72
 73    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 74    where
 75        Self: Sized,
 76        I: Iterator<Item = &'i axum::http::HeaderValue>,
 77    {
 78        let system_id = values
 79            .next()
 80            .ok_or_else(axum::headers::Error::invalid)?
 81            .to_str()
 82            .map_err(|_| axum::headers::Error::invalid())?;
 83
 84        Ok(Self(system_id.to_string()))
 85    }
 86
 87    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
 88        unimplemented!()
 89    }
 90}
 91
 92impl std::fmt::Display for SystemIdHeader {
 93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 94        write!(f, "{}", self.0)
 95    }
 96}
 97
 98pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
 99    Router::new()
100        .route("/user", get(get_authenticated_user))
101        .route("/users/:id/access_tokens", post(create_access_token))
102        .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
103        .route("/snowflake/events", post(write_snowflake_event))
104        .merge(billing::router())
105        .merge(contributors::router())
106        .layer(
107            ServiceBuilder::new()
108                .layer(Extension(rpc_server))
109                .layer(middleware::from_fn(validate_api_token)),
110        )
111}
112
113pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
114    let token = req
115        .headers()
116        .get(http::header::AUTHORIZATION)
117        .and_then(|header| header.to_str().ok())
118        .ok_or_else(|| {
119            Error::http(
120                StatusCode::BAD_REQUEST,
121                "missing authorization header".to_string(),
122            )
123        })?
124        .strip_prefix("token ")
125        .ok_or_else(|| {
126            Error::http(
127                StatusCode::BAD_REQUEST,
128                "invalid authorization header".to_string(),
129            )
130        })?;
131
132    let state = req.extensions().get::<Arc<AppState>>().unwrap();
133
134    if token != state.config.api_token {
135        Err(Error::http(
136            StatusCode::UNAUTHORIZED,
137            "invalid authorization token".to_string(),
138        ))?
139    }
140
141    Ok::<_, Error>(next.run(req).await)
142}
143
144#[derive(Debug, Deserialize)]
145struct AuthenticatedUserParams {
146    github_user_id: i32,
147    github_login: String,
148    github_email: Option<String>,
149    github_name: Option<String>,
150    github_user_created_at: chrono::DateTime<chrono::Utc>,
151}
152
153#[derive(Debug, Serialize)]
154struct AuthenticatedUserResponse {
155    user: User,
156    metrics_id: String,
157}
158
159async fn get_authenticated_user(
160    Query(params): Query<AuthenticatedUserParams>,
161    Extension(app): Extension<Arc<AppState>>,
162) -> Result<Json<AuthenticatedUserResponse>> {
163    let initial_channel_id = app.config.auto_join_channel_id;
164
165    let user = app
166        .db
167        .get_or_create_user_by_github_account(
168            &params.github_login,
169            params.github_user_id,
170            params.github_email.as_deref(),
171            params.github_name.as_deref(),
172            params.github_user_created_at,
173            initial_channel_id,
174        )
175        .await?;
176    let metrics_id = app.db.get_user_metrics_id(user.id).await?;
177    Ok(Json(AuthenticatedUserResponse { user, metrics_id }))
178}
179
180#[derive(Deserialize, Debug)]
181struct CreateUserParams {
182    github_user_id: i32,
183    github_login: String,
184    email_address: String,
185    email_confirmation_code: Option<String>,
186    #[serde(default)]
187    admin: bool,
188    #[serde(default)]
189    invite_count: i32,
190}
191
192async fn get_rpc_server_snapshot(
193    Extension(rpc_server): Extension<Arc<rpc::Server>>,
194) -> Result<ErasedJson> {
195    Ok(ErasedJson::pretty(rpc_server.snapshot().await))
196}
197
198#[derive(Deserialize)]
199struct CreateAccessTokenQueryParams {
200    public_key: String,
201    impersonate: Option<String>,
202}
203
204#[derive(Serialize)]
205struct CreateAccessTokenResponse {
206    user_id: UserId,
207    encrypted_access_token: String,
208}
209
210async fn create_access_token(
211    Path(user_id): Path<UserId>,
212    Query(params): Query<CreateAccessTokenQueryParams>,
213    Extension(app): Extension<Arc<AppState>>,
214) -> Result<Json<CreateAccessTokenResponse>> {
215    let user = app
216        .db
217        .get_user_by_id(user_id)
218        .await?
219        .ok_or_else(|| anyhow!("user not found"))?;
220
221    let mut impersonated_user_id = None;
222    if let Some(impersonate) = params.impersonate {
223        if user.admin {
224            if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
225                impersonated_user_id = Some(impersonated_user.id);
226            } else {
227                return Err(Error::http(
228                    StatusCode::UNPROCESSABLE_ENTITY,
229                    format!("user {impersonate} does not exist"),
230                ));
231            }
232        } else {
233            return Err(Error::http(
234                StatusCode::UNAUTHORIZED,
235                "you do not have permission to impersonate other users".to_string(),
236            ));
237        }
238    }
239
240    let access_token =
241        auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
242    let encrypted_access_token =
243        auth::encrypt_access_token(&access_token, params.public_key.clone())?;
244
245    Ok(Json(CreateAccessTokenResponse {
246        user_id: impersonated_user_id.unwrap_or(user_id),
247        encrypted_access_token,
248    }))
249}
250
251/// An endpoint that writes a Snowflake event to our event stream.
252///
253/// This endpoint is exposed such that other internal services can write
254/// telemetry events without needing to talk to AWS Kinesis directly.
255async fn write_snowflake_event(
256    Extension(app): Extension<Arc<AppState>>,
257    Json(event): Json<SnowflakeRow>,
258) -> Result<()> {
259    let kinesis_client = app.kinesis_client.clone();
260    let kinesis_stream = app.config.kinesis_stream.clone();
261
262    event.write(&kinesis_client, &kinesis_stream).await?;
263
264    Ok(())
265}