billing.rs

   1use anyhow::{Context, anyhow, bail};
   2use axum::{
   3    Extension, Json, Router,
   4    extract::{self, Query},
   5    routing::{get, post},
   6};
   7use chrono::{DateTime, SecondsFormat, Utc};
   8use collections::HashSet;
   9use reqwest::StatusCode;
  10use sea_orm::ActiveValue;
  11use serde::{Deserialize, Serialize};
  12use serde_json::json;
  13use std::{str::FromStr, sync::Arc, time::Duration};
  14use stripe::{
  15    BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
  16    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
  17    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
  18    CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
  19    EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
  20};
  21use util::ResultExt;
  22
  23use crate::api::events::SnowflakeRow;
  24use crate::db::billing_subscription::{
  25    StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
  26};
  27use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
  28use crate::rpc::{ResultExt as _, Server};
  29use crate::{AppState, Cents, Error, Result};
  30use crate::{db::UserId, llm::db::LlmDatabase};
  31use crate::{
  32    db::{
  33        BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
  34        CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
  35        UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
  36    },
  37    stripe_billing::StripeBilling,
  38};
  39
  40pub fn router() -> Router {
  41    Router::new()
  42        .route(
  43            "/billing/preferences",
  44            get(get_billing_preferences).put(update_billing_preferences),
  45        )
  46        .route(
  47            "/billing/subscriptions",
  48            get(list_billing_subscriptions).post(create_billing_subscription),
  49        )
  50        .route(
  51            "/billing/subscriptions/manage",
  52            post(manage_billing_subscription),
  53        )
  54        .route("/billing/monthly_spend", get(get_monthly_spend))
  55}
  56
  57#[derive(Debug, Deserialize)]
  58struct GetBillingPreferencesParams {
  59    github_user_id: i32,
  60}
  61
  62#[derive(Debug, Serialize)]
  63struct BillingPreferencesResponse {
  64    max_monthly_llm_usage_spending_in_cents: i32,
  65}
  66
  67async fn get_billing_preferences(
  68    Extension(app): Extension<Arc<AppState>>,
  69    Query(params): Query<GetBillingPreferencesParams>,
  70) -> Result<Json<BillingPreferencesResponse>> {
  71    let user = app
  72        .db
  73        .get_user_by_github_user_id(params.github_user_id)
  74        .await?
  75        .ok_or_else(|| anyhow!("user not found"))?;
  76
  77    let preferences = app.db.get_billing_preferences(user.id).await?;
  78
  79    Ok(Json(BillingPreferencesResponse {
  80        max_monthly_llm_usage_spending_in_cents: preferences
  81            .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
  82                preferences.max_monthly_llm_usage_spending_in_cents
  83            }),
  84    }))
  85}
  86
  87#[derive(Debug, Deserialize)]
  88struct UpdateBillingPreferencesBody {
  89    github_user_id: i32,
  90    max_monthly_llm_usage_spending_in_cents: i32,
  91}
  92
  93async fn update_billing_preferences(
  94    Extension(app): Extension<Arc<AppState>>,
  95    Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
  96    extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
  97) -> Result<Json<BillingPreferencesResponse>> {
  98    let user = app
  99        .db
 100        .get_user_by_github_user_id(body.github_user_id)
 101        .await?
 102        .ok_or_else(|| anyhow!("user not found"))?;
 103
 104    let max_monthly_llm_usage_spending_in_cents =
 105        body.max_monthly_llm_usage_spending_in_cents.max(0);
 106
 107    let billing_preferences =
 108        if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
 109            app.db
 110                .update_billing_preferences(
 111                    user.id,
 112                    &UpdateBillingPreferencesParams {
 113                        max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
 114                            max_monthly_llm_usage_spending_in_cents,
 115                        ),
 116                    },
 117                )
 118                .await?
 119        } else {
 120            app.db
 121                .create_billing_preferences(
 122                    user.id,
 123                    &crate::db::CreateBillingPreferencesParams {
 124                        max_monthly_llm_usage_spending_in_cents,
 125                    },
 126                )
 127                .await?
 128        };
 129
 130    SnowflakeRow::new(
 131        "Spend Limit Updated",
 132        Some(user.metrics_id),
 133        user.admin,
 134        None,
 135        json!({
 136            "user_id": user.id,
 137            "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
 138        }),
 139    )
 140    .write(&app.kinesis_client, &app.config.kinesis_stream)
 141    .await
 142    .log_err();
 143
 144    rpc_server.refresh_llm_tokens_for_user(user.id).await;
 145
 146    Ok(Json(BillingPreferencesResponse {
 147        max_monthly_llm_usage_spending_in_cents: billing_preferences
 148            .max_monthly_llm_usage_spending_in_cents,
 149    }))
 150}
 151
 152#[derive(Debug, Deserialize)]
 153struct ListBillingSubscriptionsParams {
 154    github_user_id: i32,
 155}
 156
 157#[derive(Debug, Serialize)]
 158struct BillingSubscriptionJson {
 159    id: BillingSubscriptionId,
 160    name: String,
 161    status: StripeSubscriptionStatus,
 162    cancel_at: Option<String>,
 163    /// Whether this subscription can be canceled.
 164    is_cancelable: bool,
 165}
 166
 167#[derive(Debug, Serialize)]
 168struct ListBillingSubscriptionsResponse {
 169    subscriptions: Vec<BillingSubscriptionJson>,
 170}
 171
 172async fn list_billing_subscriptions(
 173    Extension(app): Extension<Arc<AppState>>,
 174    Query(params): Query<ListBillingSubscriptionsParams>,
 175) -> Result<Json<ListBillingSubscriptionsResponse>> {
 176    let user = app
 177        .db
 178        .get_user_by_github_user_id(params.github_user_id)
 179        .await?
 180        .ok_or_else(|| anyhow!("user not found"))?;
 181
 182    let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
 183
 184    Ok(Json(ListBillingSubscriptionsResponse {
 185        subscriptions: subscriptions
 186            .into_iter()
 187            .map(|subscription| BillingSubscriptionJson {
 188                id: subscription.id,
 189                name: match subscription.kind {
 190                    Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
 191                    None => "Zed LLM Usage".to_string(),
 192                },
 193                status: subscription.stripe_subscription_status,
 194                cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
 195                    cancel_at
 196                        .and_utc()
 197                        .to_rfc3339_opts(SecondsFormat::Millis, true)
 198                }),
 199                is_cancelable: subscription.stripe_subscription_status.is_cancelable()
 200                    && subscription.stripe_cancel_at.is_none(),
 201            })
 202            .collect(),
 203    }))
 204}
 205
 206#[derive(Debug, Clone, Copy, Deserialize)]
 207#[serde(rename_all = "snake_case")]
 208enum ProductCode {
 209    ZedPro,
 210}
 211
 212#[derive(Debug, Deserialize)]
 213struct CreateBillingSubscriptionBody {
 214    github_user_id: i32,
 215    product: Option<ProductCode>,
 216}
 217
 218#[derive(Debug, Serialize)]
 219struct CreateBillingSubscriptionResponse {
 220    checkout_session_url: String,
 221}
 222
 223/// Initiates a Stripe Checkout session for creating a billing subscription.
 224async fn create_billing_subscription(
 225    Extension(app): Extension<Arc<AppState>>,
 226    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
 227) -> Result<Json<CreateBillingSubscriptionResponse>> {
 228    let user = app
 229        .db
 230        .get_user_by_github_user_id(body.github_user_id)
 231        .await?
 232        .ok_or_else(|| anyhow!("user not found"))?;
 233
 234    let Some(stripe_client) = app.stripe_client.clone() else {
 235        log::error!("failed to retrieve Stripe client");
 236        Err(Error::http(
 237            StatusCode::NOT_IMPLEMENTED,
 238            "not supported".into(),
 239        ))?
 240    };
 241    let Some(stripe_billing) = app.stripe_billing.clone() else {
 242        log::error!("failed to retrieve Stripe billing object");
 243        Err(Error::http(
 244            StatusCode::NOT_IMPLEMENTED,
 245            "not supported".into(),
 246        ))?
 247    };
 248    let Some(llm_db) = app.llm_db.clone() else {
 249        log::error!("failed to retrieve LLM database");
 250        Err(Error::http(
 251            StatusCode::NOT_IMPLEMENTED,
 252            "not supported".into(),
 253        ))?
 254    };
 255
 256    if app.db.has_active_billing_subscription(user.id).await? {
 257        return Err(Error::http(
 258            StatusCode::CONFLICT,
 259            "user already has an active subscription".into(),
 260        ));
 261    }
 262
 263    let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
 264    if let Some(existing_billing_customer) = &existing_billing_customer {
 265        if existing_billing_customer.has_overdue_invoices {
 266            return Err(Error::http(
 267                StatusCode::PAYMENT_REQUIRED,
 268                "user has overdue invoices".into(),
 269            ));
 270        }
 271    }
 272
 273    let customer_id = if let Some(existing_customer) = existing_billing_customer {
 274        CustomerId::from_str(&existing_customer.stripe_customer_id)
 275            .context("failed to parse customer ID")?
 276    } else {
 277        let customer = Customer::create(
 278            &stripe_client,
 279            CreateCustomer {
 280                email: user.email_address.as_deref(),
 281                ..Default::default()
 282            },
 283        )
 284        .await?;
 285
 286        customer.id
 287    };
 288
 289    let checkout_session_url = match body.product {
 290        Some(ProductCode::ZedPro) => {
 291            let success_url = format!(
 292                "{}/account?checkout_complete=1",
 293                app.config.zed_dot_dev_url()
 294            );
 295            stripe_billing
 296                .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
 297                .await?
 298        }
 299        None => {
 300            let default_model =
 301                llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?;
 302            let stripe_model = stripe_billing.register_model(default_model).await?;
 303            let success_url = format!(
 304                "{}/account?checkout_complete=1",
 305                app.config.zed_dot_dev_url()
 306            );
 307            stripe_billing
 308                .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
 309                .await?
 310        }
 311    };
 312
 313    Ok(Json(CreateBillingSubscriptionResponse {
 314        checkout_session_url,
 315    }))
 316}
 317
 318#[derive(Debug, PartialEq, Deserialize)]
 319#[serde(rename_all = "snake_case")]
 320enum ManageSubscriptionIntent {
 321    /// The user intends to manage their subscription.
 322    ///
 323    /// This will open the Stripe billing portal without putting the user in a specific flow.
 324    ManageSubscription,
 325    /// The user intends to cancel their subscription.
 326    Cancel,
 327    /// The user intends to stop the cancellation of their subscription.
 328    StopCancellation,
 329}
 330
 331#[derive(Debug, Deserialize)]
 332struct ManageBillingSubscriptionBody {
 333    github_user_id: i32,
 334    intent: ManageSubscriptionIntent,
 335    /// The ID of the subscription to manage.
 336    subscription_id: BillingSubscriptionId,
 337}
 338
 339#[derive(Debug, Serialize)]
 340struct ManageBillingSubscriptionResponse {
 341    billing_portal_session_url: Option<String>,
 342}
 343
 344/// Initiates a Stripe customer portal session for managing a billing subscription.
 345async fn manage_billing_subscription(
 346    Extension(app): Extension<Arc<AppState>>,
 347    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
 348) -> Result<Json<ManageBillingSubscriptionResponse>> {
 349    let user = app
 350        .db
 351        .get_user_by_github_user_id(body.github_user_id)
 352        .await?
 353        .ok_or_else(|| anyhow!("user not found"))?;
 354
 355    let Some(stripe_client) = app.stripe_client.clone() else {
 356        log::error!("failed to retrieve Stripe client");
 357        Err(Error::http(
 358            StatusCode::NOT_IMPLEMENTED,
 359            "not supported".into(),
 360        ))?
 361    };
 362
 363    let customer = app
 364        .db
 365        .get_billing_customer_by_user_id(user.id)
 366        .await?
 367        .ok_or_else(|| anyhow!("billing customer not found"))?;
 368    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
 369        .context("failed to parse customer ID")?;
 370
 371    let subscription = app
 372        .db
 373        .get_billing_subscription_by_id(body.subscription_id)
 374        .await?
 375        .ok_or_else(|| anyhow!("subscription not found"))?;
 376
 377    if body.intent == ManageSubscriptionIntent::StopCancellation {
 378        let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
 379            .context("failed to parse subscription ID")?;
 380
 381        let updated_stripe_subscription = Subscription::update(
 382            &stripe_client,
 383            &subscription_id,
 384            stripe::UpdateSubscription {
 385                cancel_at_period_end: Some(false),
 386                ..Default::default()
 387            },
 388        )
 389        .await?;
 390
 391        app.db
 392            .update_billing_subscription(
 393                subscription.id,
 394                &UpdateBillingSubscriptionParams {
 395                    stripe_cancel_at: ActiveValue::set(
 396                        updated_stripe_subscription
 397                            .cancel_at
 398                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 399                            .map(|time| time.naive_utc()),
 400                    ),
 401                    ..Default::default()
 402                },
 403            )
 404            .await?;
 405
 406        return Ok(Json(ManageBillingSubscriptionResponse {
 407            billing_portal_session_url: None,
 408        }));
 409    }
 410
 411    let flow = match body.intent {
 412        ManageSubscriptionIntent::ManageSubscription => None,
 413        ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
 414            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
 415            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
 416                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
 417                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
 418                    return_url: format!("{}/account", app.config.zed_dot_dev_url()),
 419                }),
 420                ..Default::default()
 421            }),
 422            subscription_cancel: Some(
 423                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
 424                    subscription: subscription.stripe_subscription_id,
 425                    retention: None,
 426                },
 427            ),
 428            ..Default::default()
 429        }),
 430        ManageSubscriptionIntent::StopCancellation => unreachable!(),
 431    };
 432
 433    let mut params = CreateBillingPortalSession::new(customer_id);
 434    params.flow_data = flow;
 435    let return_url = format!("{}/account", app.config.zed_dot_dev_url());
 436    params.return_url = Some(&return_url);
 437
 438    let session = BillingPortalSession::create(&stripe_client, params).await?;
 439
 440    Ok(Json(ManageBillingSubscriptionResponse {
 441        billing_portal_session_url: Some(session.url),
 442    }))
 443}
 444
 445/// The amount of time we wait in between each poll of Stripe events.
 446///
 447/// This value should strike a balance between:
 448///   1. Being short enough that we update quickly when something in Stripe changes
 449///   2. Being long enough that we don't eat into our rate limits.
 450///
 451/// As a point of reference, the Sequin folks say they have this at **500ms**:
 452///
 453/// > We poll the Stripe /events endpoint every 500ms per account
 454/// >
 455/// > — https://blog.sequinstream.com/events-not-webhooks/
 456const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
 457
 458/// The maximum number of events to return per page.
 459///
 460/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
 461///
 462/// > Limit can range between 1 and 100, and the default is 10.
 463const EVENTS_LIMIT_PER_PAGE: u64 = 100;
 464
 465/// The number of pages consisting entirely of already-processed events that we
 466/// will see before we stop retrieving events.
 467///
 468/// This is used to prevent over-fetching the Stripe events API for events we've
 469/// already seen and processed.
 470const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
 471
 472/// Polls the Stripe events API periodically to reconcile the records in our
 473/// database with the data in Stripe.
 474pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
 475    let Some(stripe_client) = app.stripe_client.clone() else {
 476        log::warn!("failed to retrieve Stripe client");
 477        return;
 478    };
 479
 480    let executor = app.executor.clone();
 481    executor.spawn_detached({
 482        let executor = executor.clone();
 483        async move {
 484            loop {
 485                poll_stripe_events(&app, &rpc_server, &stripe_client)
 486                    .await
 487                    .log_err();
 488
 489                executor.sleep(POLL_EVENTS_INTERVAL).await;
 490            }
 491        }
 492    });
 493}
 494
 495async fn poll_stripe_events(
 496    app: &Arc<AppState>,
 497    rpc_server: &Arc<Server>,
 498    stripe_client: &stripe::Client,
 499) -> anyhow::Result<()> {
 500    fn event_type_to_string(event_type: EventType) -> String {
 501        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
 502        // so we need to unquote it.
 503        event_type.to_string().trim_matches('"').to_string()
 504    }
 505
 506    let event_types = [
 507        EventType::CustomerCreated,
 508        EventType::CustomerUpdated,
 509        EventType::CustomerSubscriptionCreated,
 510        EventType::CustomerSubscriptionUpdated,
 511        EventType::CustomerSubscriptionPaused,
 512        EventType::CustomerSubscriptionResumed,
 513        EventType::CustomerSubscriptionDeleted,
 514    ]
 515    .into_iter()
 516    .map(event_type_to_string)
 517    .collect::<Vec<_>>();
 518
 519    let mut pages_of_already_processed_events = 0;
 520    let mut unprocessed_events = Vec::new();
 521
 522    log::info!(
 523        "Stripe events: starting retrieval for {}",
 524        event_types.join(", ")
 525    );
 526    let mut params = ListEvents::new();
 527    params.types = Some(event_types.clone());
 528    params.limit = Some(EVENTS_LIMIT_PER_PAGE);
 529
 530    let mut event_pages = stripe::Event::list(&stripe_client, &params)
 531        .await?
 532        .paginate(params);
 533
 534    loop {
 535        let processed_event_ids = {
 536            let event_ids = event_pages
 537                .page
 538                .data
 539                .iter()
 540                .map(|event| event.id.as_str())
 541                .collect::<Vec<_>>();
 542            app.db
 543                .get_processed_stripe_events_by_event_ids(&event_ids)
 544                .await?
 545                .into_iter()
 546                .map(|event| event.stripe_event_id)
 547                .collect::<Vec<_>>()
 548        };
 549
 550        let mut processed_events_in_page = 0;
 551        let events_in_page = event_pages.page.data.len();
 552        for event in &event_pages.page.data {
 553            if processed_event_ids.contains(&event.id.to_string()) {
 554                processed_events_in_page += 1;
 555                log::debug!("Stripe events: already processed '{}', skipping", event.id);
 556            } else {
 557                unprocessed_events.push(event.clone());
 558            }
 559        }
 560
 561        if processed_events_in_page == events_in_page {
 562            pages_of_already_processed_events += 1;
 563        }
 564
 565        if event_pages.page.has_more {
 566            if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
 567            {
 568                log::info!(
 569                    "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
 570                );
 571                break;
 572            } else {
 573                log::info!("Stripe events: retrieving next page");
 574                event_pages = event_pages.next(&stripe_client).await?;
 575            }
 576        } else {
 577            break;
 578        }
 579    }
 580
 581    log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
 582
 583    // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
 584    unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
 585
 586    for event in unprocessed_events {
 587        let event_id = event.id.clone();
 588        let processed_event_params = CreateProcessedStripeEventParams {
 589            stripe_event_id: event.id.to_string(),
 590            stripe_event_type: event_type_to_string(event.type_),
 591            stripe_event_created_timestamp: event.created,
 592        };
 593
 594        // If the event has happened too far in the past, we don't want to
 595        // process it and risk overwriting other more-recent updates.
 596        //
 597        // 1 day was chosen arbitrarily. This could be made longer or shorter.
 598        let one_day = Duration::from_secs(24 * 60 * 60);
 599        let a_day_ago = Utc::now() - one_day;
 600        if a_day_ago.timestamp() > event.created {
 601            log::info!(
 602                "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
 603                event_id
 604            );
 605            app.db
 606                .create_processed_stripe_event(&processed_event_params)
 607                .await?;
 608
 609            return Ok(());
 610        }
 611
 612        let process_result = match event.type_ {
 613            EventType::CustomerCreated | EventType::CustomerUpdated => {
 614                handle_customer_event(app, stripe_client, event).await
 615            }
 616            EventType::CustomerSubscriptionCreated
 617            | EventType::CustomerSubscriptionUpdated
 618            | EventType::CustomerSubscriptionPaused
 619            | EventType::CustomerSubscriptionResumed
 620            | EventType::CustomerSubscriptionDeleted => {
 621                handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
 622            }
 623            _ => Ok(()),
 624        };
 625
 626        if let Some(()) = process_result
 627            .with_context(|| format!("failed to process event {event_id} successfully"))
 628            .log_err()
 629        {
 630            app.db
 631                .create_processed_stripe_event(&processed_event_params)
 632                .await?;
 633        }
 634    }
 635
 636    Ok(())
 637}
 638
 639async fn handle_customer_event(
 640    app: &Arc<AppState>,
 641    _stripe_client: &stripe::Client,
 642    event: stripe::Event,
 643) -> anyhow::Result<()> {
 644    let EventObject::Customer(customer) = event.data.object else {
 645        bail!("unexpected event payload for {}", event.id);
 646    };
 647
 648    log::info!("handling Stripe {} event: {}", event.type_, event.id);
 649
 650    let Some(email) = customer.email else {
 651        log::info!("Stripe customer has no email: skipping");
 652        return Ok(());
 653    };
 654
 655    let Some(user) = app.db.get_user_by_email(&email).await? else {
 656        log::info!("no user found for email: skipping");
 657        return Ok(());
 658    };
 659
 660    if let Some(existing_customer) = app
 661        .db
 662        .get_billing_customer_by_stripe_customer_id(&customer.id)
 663        .await?
 664    {
 665        app.db
 666            .update_billing_customer(
 667                existing_customer.id,
 668                &UpdateBillingCustomerParams {
 669                    // For now we just leave the information as-is, as it is not
 670                    // likely to change.
 671                    ..Default::default()
 672                },
 673            )
 674            .await?;
 675    } else {
 676        app.db
 677            .create_billing_customer(&CreateBillingCustomerParams {
 678                user_id: user.id,
 679                stripe_customer_id: customer.id.to_string(),
 680            })
 681            .await?;
 682    }
 683
 684    Ok(())
 685}
 686
 687async fn handle_customer_subscription_event(
 688    app: &Arc<AppState>,
 689    rpc_server: &Arc<Server>,
 690    stripe_client: &stripe::Client,
 691    event: stripe::Event,
 692) -> anyhow::Result<()> {
 693    let EventObject::Subscription(subscription) = event.data.object else {
 694        bail!("unexpected event payload for {}", event.id);
 695    };
 696
 697    log::info!("handling Stripe {} event: {}", event.type_, event.id);
 698
 699    let subscription_kind =
 700        if let Some(zed_pro_price_id) = app.config.stripe_zed_pro_price_id.as_deref() {
 701            let has_zed_pro_price = subscription.items.data.iter().any(|item| {
 702                item.price
 703                    .as_ref()
 704                    .map_or(false, |price| price.id.as_str() == zed_pro_price_id)
 705            });
 706
 707            if has_zed_pro_price {
 708                Some(SubscriptionKind::ZedPro)
 709            } else {
 710                None
 711            }
 712        } else {
 713            None
 714        };
 715
 716    let billing_customer =
 717        find_or_create_billing_customer(app, stripe_client, subscription.customer)
 718            .await?
 719            .ok_or_else(|| anyhow!("billing customer not found"))?;
 720
 721    let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
 722        && subscription
 723            .cancellation_details
 724            .as_ref()
 725            .and_then(|details| details.reason)
 726            .map_or(false, |reason| {
 727                reason == CancellationDetailsReason::PaymentFailed
 728            });
 729
 730    if was_canceled_due_to_payment_failure {
 731        app.db
 732            .update_billing_customer(
 733                billing_customer.id,
 734                &UpdateBillingCustomerParams {
 735                    has_overdue_invoices: ActiveValue::set(true),
 736                    ..Default::default()
 737                },
 738            )
 739            .await?;
 740    }
 741
 742    if let Some(existing_subscription) = app
 743        .db
 744        .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
 745        .await?
 746    {
 747        app.db
 748            .update_billing_subscription(
 749                existing_subscription.id,
 750                &UpdateBillingSubscriptionParams {
 751                    billing_customer_id: ActiveValue::set(billing_customer.id),
 752                    kind: ActiveValue::set(subscription_kind),
 753                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
 754                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
 755                    stripe_cancel_at: ActiveValue::set(
 756                        subscription
 757                            .cancel_at
 758                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 759                            .map(|time| time.naive_utc()),
 760                    ),
 761                    stripe_cancellation_reason: ActiveValue::set(
 762                        subscription
 763                            .cancellation_details
 764                            .and_then(|details| details.reason)
 765                            .map(|reason| reason.into()),
 766                    ),
 767                    stripe_current_period_start: ActiveValue::set(Some(
 768                        subscription.current_period_start,
 769                    )),
 770                    stripe_current_period_end: ActiveValue::set(Some(
 771                        subscription.current_period_end,
 772                    )),
 773                },
 774            )
 775            .await?;
 776    } else {
 777        // If the user already has an active billing subscription, ignore the
 778        // event and return an `Ok` to signal that it was processed
 779        // successfully.
 780        //
 781        // There is the possibility that this could cause us to not create a
 782        // subscription in the following scenario:
 783        //
 784        //   1. User has an active subscription A
 785        //   2. User cancels subscription A
 786        //   3. User creates a new subscription B
 787        //   4. We process the new subscription B before the cancellation of subscription A
 788        //   5. User ends up with no subscriptions
 789        //
 790        // In theory this situation shouldn't arise as we try to process the events in the order they occur.
 791        if app
 792            .db
 793            .has_active_billing_subscription(billing_customer.user_id)
 794            .await?
 795        {
 796            log::info!(
 797                "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
 798                user_id = billing_customer.user_id,
 799                subscription_id = subscription.id
 800            );
 801            return Ok(());
 802        }
 803
 804        app.db
 805            .create_billing_subscription(&CreateBillingSubscriptionParams {
 806                billing_customer_id: billing_customer.id,
 807                kind: subscription_kind,
 808                stripe_subscription_id: subscription.id.to_string(),
 809                stripe_subscription_status: subscription.status.into(),
 810                stripe_cancellation_reason: subscription
 811                    .cancellation_details
 812                    .and_then(|details| details.reason)
 813                    .map(|reason| reason.into()),
 814                stripe_current_period_start: Some(subscription.current_period_start),
 815                stripe_current_period_end: Some(subscription.current_period_end),
 816            })
 817            .await?;
 818    }
 819
 820    // When the user's subscription changes, we want to refresh their LLM tokens
 821    // to either grant/revoke access.
 822    rpc_server
 823        .refresh_llm_tokens_for_user(billing_customer.user_id)
 824        .await;
 825
 826    Ok(())
 827}
 828
/// Query parameters accepted by `get_monthly_spend`.
#[derive(Debug, Deserialize)]
struct GetMonthlySpendParams {
    /// GitHub user ID of the user whose monthly spend is being requested.
    github_user_id: i32,
}
 833
/// JSON response returned by `get_monthly_spend`. All amounts are in cents.
#[derive(Debug, Serialize)]
struct GetMonthlySpendResponse {
    /// Portion of this month's spend covered by the free tier
    /// (the smaller of the total spend and the free-tier allowance).
    monthly_free_tier_spend_in_cents: u32,
    /// Free-tier allowance that applies to the user: their custom allowance if
    /// one is set, otherwise the default free-tier limit.
    monthly_free_tier_allowance_in_cents: u32,
    /// Portion of this month's spend that exceeds the free-tier allowance.
    monthly_spend_in_cents: u32,
}
 840
 841async fn get_monthly_spend(
 842    Extension(app): Extension<Arc<AppState>>,
 843    Query(params): Query<GetMonthlySpendParams>,
 844) -> Result<Json<GetMonthlySpendResponse>> {
 845    let user = app
 846        .db
 847        .get_user_by_github_user_id(params.github_user_id)
 848        .await?
 849        .ok_or_else(|| anyhow!("user not found"))?;
 850
 851    let Some(llm_db) = app.llm_db.clone() else {
 852        return Err(Error::http(
 853            StatusCode::NOT_IMPLEMENTED,
 854            "LLM database not available".into(),
 855        ));
 856    };
 857
 858    let free_tier = user
 859        .custom_llm_monthly_allowance_in_cents
 860        .map(|allowance| Cents(allowance as u32))
 861        .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
 862
 863    let spending_for_month = llm_db
 864        .get_user_spending_for_month(user.id, Utc::now())
 865        .await?;
 866
 867    let free_tier_spend = Cents::min(spending_for_month, free_tier);
 868    let monthly_spend = spending_for_month.saturating_sub(free_tier);
 869
 870    Ok(Json(GetMonthlySpendResponse {
 871        monthly_free_tier_spend_in_cents: free_tier_spend.0,
 872        monthly_free_tier_allowance_in_cents: free_tier.0,
 873        monthly_spend_in_cents: monthly_spend.0,
 874    }))
 875}
 876
 877impl From<SubscriptionStatus> for StripeSubscriptionStatus {
 878    fn from(value: SubscriptionStatus) -> Self {
 879        match value {
 880            SubscriptionStatus::Incomplete => Self::Incomplete,
 881            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
 882            SubscriptionStatus::Trialing => Self::Trialing,
 883            SubscriptionStatus::Active => Self::Active,
 884            SubscriptionStatus::PastDue => Self::PastDue,
 885            SubscriptionStatus::Canceled => Self::Canceled,
 886            SubscriptionStatus::Unpaid => Self::Unpaid,
 887            SubscriptionStatus::Paused => Self::Paused,
 888        }
 889    }
 890}
 891
 892impl From<CancellationDetailsReason> for StripeCancellationReason {
 893    fn from(value: CancellationDetailsReason) -> Self {
 894        match value {
 895            CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
 896            CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
 897            CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
 898        }
 899    }
 900}
 901
 902/// Finds or creates a billing customer using the provided customer.
 903async fn find_or_create_billing_customer(
 904    app: &Arc<AppState>,
 905    stripe_client: &stripe::Client,
 906    customer_or_id: Expandable<Customer>,
 907) -> anyhow::Result<Option<billing_customer::Model>> {
 908    let customer_id = match &customer_or_id {
 909        Expandable::Id(id) => id,
 910        Expandable::Object(customer) => customer.id.as_ref(),
 911    };
 912
 913    // If we already have a billing customer record associated with the Stripe customer,
 914    // there's nothing more we need to do.
 915    if let Some(billing_customer) = app
 916        .db
 917        .get_billing_customer_by_stripe_customer_id(customer_id)
 918        .await?
 919    {
 920        return Ok(Some(billing_customer));
 921    }
 922
 923    // If all we have is a customer ID, resolve it to a full customer record by
 924    // hitting the Stripe API.
 925    let customer = match customer_or_id {
 926        Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
 927        Expandable::Object(customer) => *customer,
 928    };
 929
 930    let Some(email) = customer.email else {
 931        return Ok(None);
 932    };
 933
 934    let Some(user) = app.db.get_user_by_email(&email).await? else {
 935        return Ok(None);
 936    };
 937
 938    let billing_customer = app
 939        .db
 940        .create_billing_customer(&CreateBillingCustomerParams {
 941            user_id: user.id,
 942            stripe_customer_id: customer.id.to_string(),
 943        })
 944        .await?;
 945
 946    Ok(Some(billing_customer))
 947}
 948
 949const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
 950
 951pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
 952    let Some(stripe_billing) = app.stripe_billing.clone() else {
 953        log::warn!("failed to retrieve Stripe billing object");
 954        return;
 955    };
 956    let Some(llm_db) = app.llm_db.clone() else {
 957        log::warn!("failed to retrieve LLM database");
 958        return;
 959    };
 960
 961    let executor = app.executor.clone();
 962    executor.spawn_detached({
 963        let executor = executor.clone();
 964        async move {
 965            loop {
 966                sync_with_stripe(&app, &llm_db, &stripe_billing)
 967                    .await
 968                    .context("failed to sync LLM usage to Stripe")
 969                    .trace_err();
 970                executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
 971            }
 972        }
 973    });
 974}
 975
 976async fn sync_with_stripe(
 977    app: &Arc<AppState>,
 978    llm_db: &Arc<LlmDatabase>,
 979    stripe_billing: &Arc<StripeBilling>,
 980) -> anyhow::Result<()> {
 981    let events = llm_db.get_billing_events().await?;
 982    let user_ids = events
 983        .iter()
 984        .map(|(event, _)| event.user_id)
 985        .collect::<HashSet<UserId>>();
 986    let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
 987
 988    for (event, model) in events {
 989        let Some((stripe_db_customer, stripe_db_subscription)) =
 990            stripe_subscriptions.get(&event.user_id)
 991        else {
 992            tracing::warn!(
 993                user_id = event.user_id.0,
 994                "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
 995            );
 996            continue;
 997        };
 998        let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
 999            .stripe_subscription_id
1000            .parse()
1001            .context("failed to parse stripe subscription id from db")?;
1002        let stripe_customer_id: stripe::CustomerId = stripe_db_customer
1003            .stripe_customer_id
1004            .parse()
1005            .context("failed to parse stripe customer id from db")?;
1006
1007        let stripe_model = stripe_billing.register_model(&model).await?;
1008        stripe_billing
1009            .subscribe_to_model(&stripe_subscription_id, &stripe_model)
1010            .await?;
1011        stripe_billing
1012            .bill_model_usage(&stripe_customer_id, &stripe_model, &event)
1013            .await?;
1014        llm_db.consume_billing_event(event.id).await?;
1015    }
1016
1017    Ok(())
1018}