api.rs

  1pub mod billing;
  2pub mod contributors;
  3pub mod events;
  4pub mod extensions;
  5pub mod ips_file;
  6pub mod slack;
  7
  8use crate::db::Database;
  9use crate::{
 10    AppState, Error, Result, auth,
 11    db::{User, UserId},
 12    rpc,
 13};
 14use ::rpc::proto;
 15use anyhow::Context as _;
 16use axum::extract;
 17use axum::{
 18    Extension, Json, Router,
 19    body::Body,
 20    extract::{Path, Query},
 21    headers::Header,
 22    http::{self, HeaderName, Request, StatusCode},
 23    middleware::{self, Next},
 24    response::IntoResponse,
 25    routing::{get, post},
 26};
 27use axum_extra::response::ErasedJson;
 28use chrono::{DateTime, Utc};
 29use serde::{Deserialize, Serialize};
 30use std::sync::{Arc, OnceLock};
 31use tower::ServiceBuilder;
 32
 33pub use extensions::fetch_extensions_from_blob_store_periodically;
 34
 35pub struct CloudflareIpCountryHeader(String);
 36
 37impl Header for CloudflareIpCountryHeader {
 38    fn name() -> &'static HeaderName {
 39        static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
 40        CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
 41    }
 42
 43    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 44    where
 45        Self: Sized,
 46        I: Iterator<Item = &'i axum::http::HeaderValue>,
 47    {
 48        let country_code = values
 49            .next()
 50            .ok_or_else(axum::headers::Error::invalid)?
 51            .to_str()
 52            .map_err(|_| axum::headers::Error::invalid())?;
 53
 54        Ok(Self(country_code.to_string()))
 55    }
 56
 57    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
 58        unimplemented!()
 59    }
 60}
 61
 62impl std::fmt::Display for CloudflareIpCountryHeader {
 63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 64        write!(f, "{}", self.0)
 65    }
 66}
 67
 68pub struct SystemIdHeader(String);
 69
 70impl Header for SystemIdHeader {
 71    fn name() -> &'static HeaderName {
 72        static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
 73        SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
 74    }
 75
 76    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 77    where
 78        Self: Sized,
 79        I: Iterator<Item = &'i axum::http::HeaderValue>,
 80    {
 81        let system_id = values
 82            .next()
 83            .ok_or_else(axum::headers::Error::invalid)?
 84            .to_str()
 85            .map_err(|_| axum::headers::Error::invalid())?;
 86
 87        Ok(Self(system_id.to_string()))
 88    }
 89
 90    fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
 91        unimplemented!()
 92    }
 93}
 94
 95impl std::fmt::Display for SystemIdHeader {
 96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 97        write!(f, "{}", self.0)
 98    }
 99}
100
101pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
102    Router::new()
103        .route("/user", get(legacy_update_or_create_authenticated_user))
104        .route("/users/look_up", get(look_up_user))
105        .route("/users/:id/access_tokens", post(create_access_token))
106        .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
107        .route("/users/:id/update_plan", post(update_plan))
108        .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
109        .merge(contributors::router())
110        .layer(
111            ServiceBuilder::new()
112                .layer(Extension(rpc_server))
113                .layer(middleware::from_fn(validate_api_token)),
114        )
115}
116
117pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
118    let token = req
119        .headers()
120        .get(http::header::AUTHORIZATION)
121        .and_then(|header| header.to_str().ok())
122        .ok_or_else(|| {
123            Error::http(
124                StatusCode::BAD_REQUEST,
125                "missing authorization header".to_string(),
126            )
127        })?
128        .strip_prefix("token ")
129        .ok_or_else(|| {
130            Error::http(
131                StatusCode::BAD_REQUEST,
132                "invalid authorization header".to_string(),
133            )
134        })?;
135
136    let state = req.extensions().get::<Arc<AppState>>().unwrap();
137
138    if token != state.config.api_token {
139        Err(Error::http(
140            StatusCode::UNAUTHORIZED,
141            "invalid authorization token".to_string(),
142        ))?
143    }
144
145    Ok::<_, Error>(next.run(req).await)
146}
147
148#[derive(Debug, Deserialize)]
149struct AuthenticatedUserParams {
150    github_user_id: i32,
151    github_login: String,
152    github_email: Option<String>,
153    github_name: Option<String>,
154    github_user_created_at: chrono::DateTime<chrono::Utc>,
155}
156
157#[derive(Debug, Serialize)]
158struct AuthenticatedUserResponse {
159    user: User,
160    metrics_id: String,
161    feature_flags: Vec<String>,
162}
163
164/// This is a legacy endpoint that is no longer used in production.
165///
166/// It currently only exists to be used when developing Collab locally.
167async fn legacy_update_or_create_authenticated_user(
168    Query(params): Query<AuthenticatedUserParams>,
169    Extension(app): Extension<Arc<AppState>>,
170) -> Result<Json<AuthenticatedUserResponse>> {
171    let initial_channel_id = app.config.auto_join_channel_id;
172
173    let user = app
174        .db
175        .update_or_create_user_by_github_account(
176            &params.github_login,
177            params.github_user_id,
178            params.github_email.as_deref(),
179            params.github_name.as_deref(),
180            params.github_user_created_at,
181            initial_channel_id,
182        )
183        .await?;
184    let metrics_id = app.db.get_user_metrics_id(user.id).await?;
185    let feature_flags = app.db.get_user_flags(user.id).await?;
186    Ok(Json(AuthenticatedUserResponse {
187        user,
188        metrics_id,
189        feature_flags,
190    }))
191}
192
193#[derive(Debug, Deserialize)]
194struct LookUpUserParams {
195    identifier: String,
196}
197
198#[derive(Debug, Serialize)]
199struct LookUpUserResponse {
200    user: Option<User>,
201}
202
203async fn look_up_user(
204    Query(params): Query<LookUpUserParams>,
205    Extension(app): Extension<Arc<AppState>>,
206) -> Result<Json<LookUpUserResponse>> {
207    let user = resolve_identifier_to_user(&app.db, &params.identifier).await?;
208    let user = if let Some(user) = user {
209        match user {
210            UserOrId::User(user) => Some(user),
211            UserOrId::Id(id) => app.db.get_user_by_id(id).await?,
212        }
213    } else {
214        None
215    };
216
217    Ok(Json(LookUpUserResponse { user }))
218}
219
220enum UserOrId {
221    User(User),
222    Id(UserId),
223}
224
225async fn resolve_identifier_to_user(
226    db: &Arc<Database>,
227    identifier: &str,
228) -> Result<Option<UserOrId>> {
229    if let Some(identifier) = identifier.parse::<i32>().ok() {
230        let user = db.get_user_by_id(UserId(identifier)).await?;
231
232        return Ok(user.map(UserOrId::User));
233    }
234
235    if identifier.starts_with("cus_") {
236        let billing_customer = db
237            .get_billing_customer_by_stripe_customer_id(&identifier)
238            .await?;
239
240        return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)));
241    }
242
243    if identifier.starts_with("sub_") {
244        let billing_subscription = db
245            .get_billing_subscription_by_stripe_subscription_id(&identifier)
246            .await?;
247
248        if let Some(billing_subscription) = billing_subscription {
249            let billing_customer = db
250                .get_billing_customer_by_id(billing_subscription.billing_customer_id)
251                .await?;
252
253            return Ok(
254                billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))
255            );
256        } else {
257            return Ok(None);
258        }
259    }
260
261    if identifier.contains('@') {
262        let user = db.get_user_by_email(identifier).await?;
263
264        return Ok(user.map(UserOrId::User));
265    }
266
267    if let Some(user) = db.get_user_by_github_login(identifier).await? {
268        return Ok(Some(UserOrId::User(user)));
269    }
270
271    Ok(None)
272}
273
274#[derive(Deserialize, Debug)]
275struct CreateUserParams {
276    github_user_id: i32,
277    github_login: String,
278    email_address: String,
279    email_confirmation_code: Option<String>,
280    #[serde(default)]
281    admin: bool,
282    #[serde(default)]
283    invite_count: i32,
284}
285
286async fn get_rpc_server_snapshot(
287    Extension(rpc_server): Extension<Arc<rpc::Server>>,
288) -> Result<ErasedJson> {
289    Ok(ErasedJson::pretty(rpc_server.snapshot().await))
290}
291
292#[derive(Deserialize)]
293struct CreateAccessTokenQueryParams {
294    public_key: String,
295    impersonate: Option<String>,
296}
297
298#[derive(Serialize)]
299struct CreateAccessTokenResponse {
300    user_id: UserId,
301    encrypted_access_token: String,
302}
303
304async fn create_access_token(
305    Path(user_id): Path<UserId>,
306    Query(params): Query<CreateAccessTokenQueryParams>,
307    Extension(app): Extension<Arc<AppState>>,
308) -> Result<Json<CreateAccessTokenResponse>> {
309    let user = app
310        .db
311        .get_user_by_id(user_id)
312        .await?
313        .context("user not found")?;
314
315    let mut impersonated_user_id = None;
316    if let Some(impersonate) = params.impersonate {
317        if user.admin {
318            if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
319                impersonated_user_id = Some(impersonated_user.id);
320            } else {
321                return Err(Error::http(
322                    StatusCode::UNPROCESSABLE_ENTITY,
323                    format!("user {impersonate} does not exist"),
324                ));
325            }
326        } else {
327            return Err(Error::http(
328                StatusCode::UNAUTHORIZED,
329                "you do not have permission to impersonate other users".to_string(),
330            ));
331        }
332    }
333
334    let access_token =
335        auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
336    let encrypted_access_token =
337        auth::encrypt_access_token(&access_token, params.public_key.clone())?;
338
339    Ok(Json(CreateAccessTokenResponse {
340        user_id: impersonated_user_id.unwrap_or(user_id),
341        encrypted_access_token,
342    }))
343}
344
345#[derive(Serialize)]
346struct RefreshLlmTokensResponse {}
347
348async fn refresh_llm_tokens(
349    Path(user_id): Path<UserId>,
350    Extension(rpc_server): Extension<Arc<rpc::Server>>,
351) -> Result<Json<RefreshLlmTokensResponse>> {
352    rpc_server.refresh_llm_tokens_for_user(user_id).await;
353
354    Ok(Json(RefreshLlmTokensResponse {}))
355}
356
357#[derive(Debug, Serialize, Deserialize)]
358struct UpdatePlanBody {
359    pub plan: cloud_llm_client::Plan,
360    pub subscription_period: SubscriptionPeriod,
361    pub usage: cloud_llm_client::CurrentUsage,
362    pub trial_started_at: Option<DateTime<Utc>>,
363    pub is_usage_based_billing_enabled: bool,
364    pub is_account_too_young: bool,
365    pub has_overdue_invoices: bool,
366}
367
368#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
369struct SubscriptionPeriod {
370    pub started_at: DateTime<Utc>,
371    pub ended_at: DateTime<Utc>,
372}
373
374#[derive(Serialize)]
375struct UpdatePlanResponse {}
376
377async fn update_plan(
378    Path(user_id): Path<UserId>,
379    Extension(rpc_server): Extension<Arc<rpc::Server>>,
380    extract::Json(body): extract::Json<UpdatePlanBody>,
381) -> Result<Json<UpdatePlanResponse>> {
382    let plan = match body.plan {
383        cloud_llm_client::Plan::ZedFree => proto::Plan::Free,
384        cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
385        cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
386    };
387
388    let update_user_plan = proto::UpdateUserPlan {
389        plan: plan.into(),
390        trial_started_at: body
391            .trial_started_at
392            .map(|trial_started_at| trial_started_at.timestamp() as u64),
393        is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled),
394        usage: Some(proto::SubscriptionUsage {
395            model_requests_usage_amount: body.usage.model_requests.used,
396            model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)),
397            edit_predictions_usage_amount: body.usage.edit_predictions.used,
398            edit_predictions_usage_limit: Some(usage_limit_to_proto(
399                body.usage.edit_predictions.limit,
400            )),
401        }),
402        subscription_period: Some(proto::SubscriptionPeriod {
403            started_at: body.subscription_period.started_at.timestamp() as u64,
404            ended_at: body.subscription_period.ended_at.timestamp() as u64,
405        }),
406        account_too_young: Some(body.is_account_too_young),
407        has_overdue_invoices: Some(body.has_overdue_invoices),
408    };
409
410    rpc_server
411        .update_plan_for_user(user_id, update_user_plan)
412        .await?;
413
414    Ok(Json(UpdatePlanResponse {}))
415}
416
417fn usage_limit_to_proto(limit: cloud_llm_client::UsageLimit) -> proto::UsageLimit {
418    proto::UsageLimit {
419        variant: Some(match limit {
420            cloud_llm_client::UsageLimit::Limited(limit) => {
421                proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
422                    limit: limit as u32,
423                })
424            }
425            cloud_llm_client::UsageLimit::Unlimited => {
426                proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
427            }
428        }),
429    }
430}