billing.rs

   1use anyhow::{Context as _, bail};
   2use axum::routing::put;
   3use axum::{Extension, Json, Router, extract, routing::post};
   4use chrono::{DateTime, SecondsFormat, Utc};
   5use collections::{HashMap, HashSet};
   6use reqwest::StatusCode;
   7use sea_orm::ActiveValue;
   8use serde::{Deserialize, Serialize};
   9use serde_json::json;
  10use std::{str::FromStr, sync::Arc, time::Duration};
  11use stripe::{
  12    BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
  13    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
  14    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
  15    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
  16    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
  17    CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
  18    PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
  19};
  20use util::{ResultExt, maybe};
  21use zed_llm_client::LanguageModelProvider;
  22
  23use crate::api::events::SnowflakeRow;
  24use crate::db::billing_subscription::{
  25    StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
  26};
  27use crate::llm::db::subscription_usage_meter::{self, CompletionMode};
  28use crate::rpc::{ResultExt as _, Server};
  29use crate::stripe_client::{
  30    StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
  31    StripeSubscriptionId, UpdateCustomerParams,
  32};
  33use crate::{AppState, Error, Result};
  34use crate::{db::UserId, llm::db::LlmDatabase};
  35use crate::{
  36    db::{
  37        BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
  38        CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
  39        UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
  40    },
  41    stripe_billing::StripeBilling,
  42};
  43
  44pub fn router() -> Router {
  45    Router::new()
  46        .route("/billing/preferences", put(update_billing_preferences))
  47        .route("/billing/subscriptions", post(create_billing_subscription))
  48        .route(
  49            "/billing/subscriptions/manage",
  50            post(manage_billing_subscription),
  51        )
  52        .route(
  53            "/billing/subscriptions/sync",
  54            post(sync_billing_subscription),
  55        )
  56}
  57
/// Response body returned after updating a user's billing preferences.
#[derive(Debug, Serialize)]
struct BillingPreferencesResponse {
    /// When the user's trial started, as an RFC 3339 UTC timestamp with
    /// millisecond precision. `None` when the user has no billing customer or
    /// no trial start recorded.
    trial_started_at: Option<String>,
    /// The stored monthly LLM usage spending cap, in cents.
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages are enabled for the user.
    model_request_overages_enabled: bool,
    /// The stored spend limit for model request overages, in cents.
    model_request_overages_spend_limit_in_cents: i32,
}
  65
/// Request body for `PUT /billing/preferences`.
///
/// Every field except `github_user_id` is optional on the wire. Omitted
/// numeric fields deserialize as `0`, and an omitted
/// `model_request_overages_enabled` deserializes as `false`.
#[derive(Debug, Deserialize)]
struct UpdateBillingPreferencesBody {
    /// GitHub ID of the user whose preferences are being updated.
    github_user_id: i32,
    /// Requested monthly LLM usage spending cap, in cents.
    #[serde(default)]
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages should be enabled.
    #[serde(default)]
    model_request_overages_enabled: bool,
    /// Requested spend limit for model request overages, in cents.
    #[serde(default)]
    model_request_overages_spend_limit_in_cents: i32,
}
  76
  77async fn update_billing_preferences(
  78    Extension(app): Extension<Arc<AppState>>,
  79    Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
  80    extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
  81) -> Result<Json<BillingPreferencesResponse>> {
  82    let user = app
  83        .db
  84        .get_user_by_github_user_id(body.github_user_id)
  85        .await?
  86        .context("user not found")?;
  87
  88    let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
  89
  90    let max_monthly_llm_usage_spending_in_cents =
  91        body.max_monthly_llm_usage_spending_in_cents.max(0);
  92    let model_request_overages_spend_limit_in_cents =
  93        body.model_request_overages_spend_limit_in_cents.max(0);
  94
  95    let billing_preferences =
  96        if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
  97            app.db
  98                .update_billing_preferences(
  99                    user.id,
 100                    &UpdateBillingPreferencesParams {
 101                        max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
 102                            max_monthly_llm_usage_spending_in_cents,
 103                        ),
 104                        model_request_overages_enabled: ActiveValue::set(
 105                            body.model_request_overages_enabled,
 106                        ),
 107                        model_request_overages_spend_limit_in_cents: ActiveValue::set(
 108                            model_request_overages_spend_limit_in_cents,
 109                        ),
 110                    },
 111                )
 112                .await?
 113        } else {
 114            app.db
 115                .create_billing_preferences(
 116                    user.id,
 117                    &crate::db::CreateBillingPreferencesParams {
 118                        max_monthly_llm_usage_spending_in_cents,
 119                        model_request_overages_enabled: body.model_request_overages_enabled,
 120                        model_request_overages_spend_limit_in_cents,
 121                    },
 122                )
 123                .await?
 124        };
 125
 126    SnowflakeRow::new(
 127        "Billing Preferences Updated",
 128        Some(user.metrics_id),
 129        user.admin,
 130        None,
 131        json!({
 132            "user_id": user.id,
 133            "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
 134            "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
 135            "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
 136        }),
 137    )
 138    .write(&app.kinesis_client, &app.config.kinesis_stream)
 139    .await
 140    .log_err();
 141
 142    rpc_server.refresh_llm_tokens_for_user(user.id).await;
 143
 144    Ok(Json(BillingPreferencesResponse {
 145        trial_started_at: billing_customer
 146            .and_then(|billing_customer| billing_customer.trial_started_at)
 147            .map(|trial_started_at| {
 148                trial_started_at
 149                    .and_utc()
 150                    .to_rfc3339_opts(SecondsFormat::Millis, true)
 151            }),
 152        max_monthly_llm_usage_spending_in_cents: billing_preferences
 153            .max_monthly_llm_usage_spending_in_cents,
 154        model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
 155        model_request_overages_spend_limit_in_cents: billing_preferences
 156            .model_request_overages_spend_limit_in_cents,
 157    }))
 158}
 159
/// The product a user is checking out via Stripe Checkout.
///
/// Deserialized from its `snake_case` name (`"zed_pro"` / `"zed_pro_trial"`).
#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ProductCode {
    /// A paid Zed Pro subscription.
    ZedPro,
    /// A Zed Pro trial.
    ZedProTrial,
}
 166
/// Request body for `POST /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub ID of the user starting the checkout.
    github_user_id: i32,
    /// The product being checked out.
    product: ProductCode,
}
 172
/// Response body for `POST /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the Stripe Checkout session to send the user to.
    checkout_session_url: String,
}
 177
 178/// Initiates a Stripe Checkout session for creating a billing subscription.
 179async fn create_billing_subscription(
 180    Extension(app): Extension<Arc<AppState>>,
 181    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
 182) -> Result<Json<CreateBillingSubscriptionResponse>> {
 183    let user = app
 184        .db
 185        .get_user_by_github_user_id(body.github_user_id)
 186        .await?
 187        .context("user not found")?;
 188
 189    let Some(stripe_billing) = app.stripe_billing.clone() else {
 190        log::error!("failed to retrieve Stripe billing object");
 191        Err(Error::http(
 192            StatusCode::NOT_IMPLEMENTED,
 193            "not supported".into(),
 194        ))?
 195    };
 196
 197    if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
 198        let is_checkout_allowed = body.product == ProductCode::ZedProTrial
 199            && existing_subscription.kind == Some(SubscriptionKind::ZedFree);
 200
 201        if !is_checkout_allowed {
 202            return Err(Error::http(
 203                StatusCode::CONFLICT,
 204                "user already has an active subscription".into(),
 205            ));
 206        }
 207    }
 208
 209    let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
 210    if let Some(existing_billing_customer) = &existing_billing_customer {
 211        if existing_billing_customer.has_overdue_invoices {
 212            return Err(Error::http(
 213                StatusCode::PAYMENT_REQUIRED,
 214                "user has overdue invoices".into(),
 215            ));
 216        }
 217    }
 218
 219    let customer_id = if let Some(existing_customer) = &existing_billing_customer {
 220        let customer_id = StripeCustomerId(existing_customer.stripe_customer_id.clone().into());
 221        if let Some(email) = user.email_address.as_deref() {
 222            stripe_billing
 223                .client()
 224                .update_customer(&customer_id, UpdateCustomerParams { email: Some(email) })
 225                .await
 226                // Update of email address is best-effort - continue checkout even if it fails
 227                .context("error updating stripe customer email address")
 228                .log_err();
 229        }
 230        customer_id
 231    } else {
 232        stripe_billing
 233            .find_or_create_customer_by_email(user.email_address.as_deref())
 234            .await?
 235    };
 236
 237    let success_url = format!(
 238        "{}/account?checkout_complete=1",
 239        app.config.zed_dot_dev_url()
 240    );
 241
 242    let checkout_session_url = match body.product {
 243        ProductCode::ZedPro => {
 244            stripe_billing
 245                .checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
 246                .await?
 247        }
 248        ProductCode::ZedProTrial => {
 249            if let Some(existing_billing_customer) = &existing_billing_customer {
 250                if existing_billing_customer.trial_started_at.is_some() {
 251                    return Err(Error::http(
 252                        StatusCode::FORBIDDEN,
 253                        "user already used free trial".into(),
 254                    ));
 255                }
 256            }
 257
 258            let feature_flags = app.db.get_user_flags(user.id).await?;
 259
 260            stripe_billing
 261                .checkout_with_zed_pro_trial(
 262                    &customer_id,
 263                    &user.github_login,
 264                    feature_flags,
 265                    &success_url,
 266                )
 267                .await?
 268        }
 269    };
 270
 271    Ok(Json(CreateBillingSubscriptionResponse {
 272        checkout_session_url,
 273    }))
 274}
 275
/// What the user wants to accomplish when calling
/// `POST /billing/subscriptions/manage`.
///
/// Deserialized from its `snake_case` name (e.g. `"update_payment_method"`).
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to manage their subscription.
    ///
    /// This will open the Stripe billing portal without putting the user in a specific flow.
    ManageSubscription,
    /// The user intends to update their payment method.
    UpdatePaymentMethod,
    /// The user intends to upgrade to Zed Pro.
    UpgradeToPro,
    /// The user intends to cancel their subscription.
    Cancel,
    /// The user intends to stop the cancellation of their subscription.
    StopCancellation,
}
 292
/// Request body for `POST /billing/subscriptions/manage`.
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub ID of the user making the request.
    github_user_id: i32,
    /// What the user wants to do with the subscription.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    subscription_id: BillingSubscriptionId,
    /// Path appended to `zed_dot_dev_url()` to build the post-completion
    /// redirect. Only read for `UpdatePaymentMethod`; defaults to `/account`.
    redirect_to: Option<String>,
}
 301
/// Response body for `POST /billing/subscriptions/manage`.
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the Stripe billing portal session. `None` when the request was
    /// fulfilled directly without opening the portal (stopping a cancellation,
    /// or ending a Zed Pro trial early).
    billing_portal_session_url: Option<String>,
}
 306
 307/// Initiates a Stripe customer portal session for managing a billing subscription.
 308async fn manage_billing_subscription(
 309    Extension(app): Extension<Arc<AppState>>,
 310    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
 311) -> Result<Json<ManageBillingSubscriptionResponse>> {
 312    let user = app
 313        .db
 314        .get_user_by_github_user_id(body.github_user_id)
 315        .await?
 316        .context("user not found")?;
 317
 318    let Some(stripe_client) = app.real_stripe_client.clone() else {
 319        log::error!("failed to retrieve Stripe client");
 320        Err(Error::http(
 321            StatusCode::NOT_IMPLEMENTED,
 322            "not supported".into(),
 323        ))?
 324    };
 325
 326    let Some(stripe_billing) = app.stripe_billing.clone() else {
 327        log::error!("failed to retrieve Stripe billing object");
 328        Err(Error::http(
 329            StatusCode::NOT_IMPLEMENTED,
 330            "not supported".into(),
 331        ))?
 332    };
 333
 334    let customer = app
 335        .db
 336        .get_billing_customer_by_user_id(user.id)
 337        .await?
 338        .context("billing customer not found")?;
 339    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
 340        .context("failed to parse customer ID")?;
 341
 342    let subscription = app
 343        .db
 344        .get_billing_subscription_by_id(body.subscription_id)
 345        .await?
 346        .context("subscription not found")?;
 347    let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
 348        .context("failed to parse subscription ID")?;
 349
 350    if body.intent == ManageSubscriptionIntent::StopCancellation {
 351        let updated_stripe_subscription = Subscription::update(
 352            &stripe_client,
 353            &subscription_id,
 354            stripe::UpdateSubscription {
 355                cancel_at_period_end: Some(false),
 356                ..Default::default()
 357            },
 358        )
 359        .await?;
 360
 361        app.db
 362            .update_billing_subscription(
 363                subscription.id,
 364                &UpdateBillingSubscriptionParams {
 365                    stripe_cancel_at: ActiveValue::set(
 366                        updated_stripe_subscription
 367                            .cancel_at
 368                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 369                            .map(|time| time.naive_utc()),
 370                    ),
 371                    ..Default::default()
 372                },
 373            )
 374            .await?;
 375
 376        return Ok(Json(ManageBillingSubscriptionResponse {
 377            billing_portal_session_url: None,
 378        }));
 379    }
 380
 381    let flow = match body.intent {
 382        ManageSubscriptionIntent::ManageSubscription => None,
 383        ManageSubscriptionIntent::UpgradeToPro => {
 384            let zed_pro_price_id: stripe::PriceId =
 385                stripe_billing.zed_pro_price_id().await?.try_into()?;
 386            let zed_free_price_id: stripe::PriceId =
 387                stripe_billing.zed_free_price_id().await?.try_into()?;
 388
 389            let stripe_subscription =
 390                Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
 391
 392            let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
 393                && stripe_subscription.items.data.iter().any(|item| {
 394                    item.price
 395                        .as_ref()
 396                        .map_or(false, |price| price.id == zed_pro_price_id)
 397                });
 398            if is_on_zed_pro_trial {
 399                let payment_methods = PaymentMethod::list(
 400                    &stripe_client,
 401                    &stripe::ListPaymentMethods {
 402                        customer: Some(stripe_subscription.customer.id()),
 403                        ..Default::default()
 404                    },
 405                )
 406                .await?;
 407
 408                let has_payment_method = !payment_methods.data.is_empty();
 409                if !has_payment_method {
 410                    return Err(Error::http(
 411                        StatusCode::BAD_REQUEST,
 412                        "missing payment method".into(),
 413                    ));
 414                }
 415
 416                // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
 417                Subscription::update(
 418                    &stripe_client,
 419                    &stripe_subscription.id,
 420                    stripe::UpdateSubscription {
 421                        trial_end: Some(stripe::Scheduled::now()),
 422                        ..Default::default()
 423                    },
 424                )
 425                .await?;
 426
 427                return Ok(Json(ManageBillingSubscriptionResponse {
 428                    billing_portal_session_url: None,
 429                }));
 430            }
 431
 432            let subscription_item_to_update = stripe_subscription
 433                .items
 434                .data
 435                .iter()
 436                .find_map(|item| {
 437                    let price = item.price.as_ref()?;
 438
 439                    if price.id == zed_free_price_id {
 440                        Some(item.id.clone())
 441                    } else {
 442                        None
 443                    }
 444                })
 445                .context("No subscription item to update")?;
 446
 447            Some(CreateBillingPortalSessionFlowData {
 448                type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
 449                subscription_update_confirm: Some(
 450                    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
 451                        subscription: subscription.stripe_subscription_id,
 452                        items: vec![
 453                            CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
 454                                id: subscription_item_to_update.to_string(),
 455                                price: Some(zed_pro_price_id.to_string()),
 456                                quantity: Some(1),
 457                            },
 458                        ],
 459                        discounts: None,
 460                    },
 461                ),
 462                ..Default::default()
 463            })
 464        }
 465        ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
 466            type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
 467            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
 468                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
 469                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
 470                    return_url: format!(
 471                        "{}{path}",
 472                        app.config.zed_dot_dev_url(),
 473                        path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
 474                    ),
 475                }),
 476                ..Default::default()
 477            }),
 478            ..Default::default()
 479        }),
 480        ManageSubscriptionIntent::Cancel => {
 481            if subscription.kind == Some(SubscriptionKind::ZedFree) {
 482                return Err(Error::http(
 483                    StatusCode::BAD_REQUEST,
 484                    "free subscription cannot be canceled".into(),
 485                ));
 486            }
 487
 488            Some(CreateBillingPortalSessionFlowData {
 489                type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
 490                after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
 491                    type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
 492                    redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
 493                        return_url: format!("{}/account", app.config.zed_dot_dev_url()),
 494                    }),
 495                    ..Default::default()
 496                }),
 497                subscription_cancel: Some(
 498                    stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
 499                        subscription: subscription.stripe_subscription_id,
 500                        retention: None,
 501                    },
 502                ),
 503                ..Default::default()
 504            })
 505        }
 506        ManageSubscriptionIntent::StopCancellation => unreachable!(),
 507    };
 508
 509    let mut params = CreateBillingPortalSession::new(customer_id);
 510    params.flow_data = flow;
 511    let return_url = format!("{}/account", app.config.zed_dot_dev_url());
 512    params.return_url = Some(&return_url);
 513
 514    let session = BillingPortalSession::create(&stripe_client, params).await?;
 515
 516    Ok(Json(ManageBillingSubscriptionResponse {
 517        billing_portal_session_url: Some(session.url),
 518    }))
 519}
 520
/// Request body for `POST /billing/subscriptions/sync`.
#[derive(Debug, Deserialize)]
struct SyncBillingSubscriptionBody {
    /// GitHub ID of the user whose subscriptions should be synced.
    github_user_id: i32,
}
 525
/// Response body for `POST /billing/subscriptions/sync`.
#[derive(Debug, Serialize)]
struct SyncBillingSubscriptionResponse {
    /// The Stripe customer ID stored on the user's billing customer record.
    stripe_customer_id: String,
}
 530
 531async fn sync_billing_subscription(
 532    Extension(app): Extension<Arc<AppState>>,
 533    extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
 534) -> Result<Json<SyncBillingSubscriptionResponse>> {
 535    let Some(stripe_client) = app.stripe_client.clone() else {
 536        log::error!("failed to retrieve Stripe client");
 537        Err(Error::http(
 538            StatusCode::NOT_IMPLEMENTED,
 539            "not supported".into(),
 540        ))?
 541    };
 542
 543    let user = app
 544        .db
 545        .get_user_by_github_user_id(body.github_user_id)
 546        .await?
 547        .context("user not found")?;
 548
 549    let billing_customer = app
 550        .db
 551        .get_billing_customer_by_user_id(user.id)
 552        .await?
 553        .context("billing customer not found")?;
 554    let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
 555
 556    let subscriptions = stripe_client
 557        .list_subscriptions_for_customer(&stripe_customer_id)
 558        .await?;
 559
 560    for subscription in subscriptions {
 561        let subscription_id = subscription.id.clone();
 562
 563        sync_subscription(&app, &stripe_client, subscription)
 564            .await
 565            .with_context(|| {
 566                format!(
 567                    "failed to sync subscription {subscription_id} for user {}",
 568                    user.id,
 569                )
 570            })?;
 571    }
 572
 573    Ok(Json(SyncBillingSubscriptionResponse {
 574        stripe_customer_id: billing_customer.stripe_customer_id.clone(),
 575    }))
 576}
 577
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between:
///   1. Being short enough that we update quickly when something in Stripe changes
///   2. Being long enough that we don't eat into our rate limits.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — <https://blog.sequinstream.com/events-not-webhooks/>
///
/// We poll considerably less often than that (every 5 seconds).
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);

/// The maximum number of events to return per page.
///
/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
/// From the Stripe API docs for the `limit` list parameter:
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;

/// The number of pages consisting entirely of already-processed events that we
/// will see before we stop retrieving events.
///
/// This is used to prevent over-fetching the Stripe events API for events we've
/// already seen and processed. The pages do not have to be consecutive, and an
/// empty page also counts as "entirely already-processed".
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
 604
 605/// Polls the Stripe events API periodically to reconcile the records in our
 606/// database with the data in Stripe.
 607pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
 608    let Some(real_stripe_client) = app.real_stripe_client.clone() else {
 609        log::warn!("failed to retrieve Stripe client");
 610        return;
 611    };
 612    let Some(stripe_client) = app.stripe_client.clone() else {
 613        log::warn!("failed to retrieve Stripe client");
 614        return;
 615    };
 616
 617    let executor = app.executor.clone();
 618    executor.spawn_detached({
 619        let executor = executor.clone();
 620        async move {
 621            loop {
 622                poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
 623                    .await
 624                    .log_err();
 625
 626                executor.sleep(POLL_EVENTS_INTERVAL).await;
 627            }
 628        }
 629    });
 630}
 631
 632async fn poll_stripe_events(
 633    app: &Arc<AppState>,
 634    rpc_server: &Arc<Server>,
 635    stripe_client: &Arc<dyn StripeClient>,
 636    real_stripe_client: &stripe::Client,
 637) -> anyhow::Result<()> {
 638    fn event_type_to_string(event_type: EventType) -> String {
 639        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
 640        // so we need to unquote it.
 641        event_type.to_string().trim_matches('"').to_string()
 642    }
 643
 644    let event_types = [
 645        EventType::CustomerCreated,
 646        EventType::CustomerUpdated,
 647        EventType::CustomerSubscriptionCreated,
 648        EventType::CustomerSubscriptionUpdated,
 649        EventType::CustomerSubscriptionPaused,
 650        EventType::CustomerSubscriptionResumed,
 651        EventType::CustomerSubscriptionDeleted,
 652    ]
 653    .into_iter()
 654    .map(event_type_to_string)
 655    .collect::<Vec<_>>();
 656
 657    let mut pages_of_already_processed_events = 0;
 658    let mut unprocessed_events = Vec::new();
 659
 660    log::info!(
 661        "Stripe events: starting retrieval for {}",
 662        event_types.join(", ")
 663    );
 664    let mut params = ListEvents::new();
 665    params.types = Some(event_types.clone());
 666    params.limit = Some(EVENTS_LIMIT_PER_PAGE);
 667
 668    let mut event_pages = stripe::Event::list(&real_stripe_client, &params)
 669        .await?
 670        .paginate(params);
 671
 672    loop {
 673        let processed_event_ids = {
 674            let event_ids = event_pages
 675                .page
 676                .data
 677                .iter()
 678                .map(|event| event.id.as_str())
 679                .collect::<Vec<_>>();
 680            app.db
 681                .get_processed_stripe_events_by_event_ids(&event_ids)
 682                .await?
 683                .into_iter()
 684                .map(|event| event.stripe_event_id)
 685                .collect::<Vec<_>>()
 686        };
 687
 688        let mut processed_events_in_page = 0;
 689        let events_in_page = event_pages.page.data.len();
 690        for event in &event_pages.page.data {
 691            if processed_event_ids.contains(&event.id.to_string()) {
 692                processed_events_in_page += 1;
 693                log::debug!("Stripe events: already processed '{}', skipping", event.id);
 694            } else {
 695                unprocessed_events.push(event.clone());
 696            }
 697        }
 698
 699        if processed_events_in_page == events_in_page {
 700            pages_of_already_processed_events += 1;
 701        }
 702
 703        if event_pages.page.has_more {
 704            if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
 705            {
 706                log::info!(
 707                    "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
 708                );
 709                break;
 710            } else {
 711                log::info!("Stripe events: retrieving next page");
 712                event_pages = event_pages.next(&real_stripe_client).await?;
 713            }
 714        } else {
 715            break;
 716        }
 717    }
 718
 719    log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
 720
 721    // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
 722    unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
 723
 724    for event in unprocessed_events {
 725        let event_id = event.id.clone();
 726        let processed_event_params = CreateProcessedStripeEventParams {
 727            stripe_event_id: event.id.to_string(),
 728            stripe_event_type: event_type_to_string(event.type_),
 729            stripe_event_created_timestamp: event.created,
 730        };
 731
 732        // If the event has happened too far in the past, we don't want to
 733        // process it and risk overwriting other more-recent updates.
 734        //
 735        // 1 day was chosen arbitrarily. This could be made longer or shorter.
 736        let one_day = Duration::from_secs(24 * 60 * 60);
 737        let a_day_ago = Utc::now() - one_day;
 738        if a_day_ago.timestamp() > event.created {
 739            log::info!(
 740                "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
 741                event_id
 742            );
 743            app.db
 744                .create_processed_stripe_event(&processed_event_params)
 745                .await?;
 746
 747            continue;
 748        }
 749
 750        let process_result = match event.type_ {
 751            EventType::CustomerCreated | EventType::CustomerUpdated => {
 752                handle_customer_event(app, real_stripe_client, event).await
 753            }
 754            EventType::CustomerSubscriptionCreated
 755            | EventType::CustomerSubscriptionUpdated
 756            | EventType::CustomerSubscriptionPaused
 757            | EventType::CustomerSubscriptionResumed
 758            | EventType::CustomerSubscriptionDeleted => {
 759                handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
 760            }
 761            _ => Ok(()),
 762        };
 763
 764        if let Some(()) = process_result
 765            .with_context(|| format!("failed to process event {event_id} successfully"))
 766            .log_err()
 767        {
 768            app.db
 769                .create_processed_stripe_event(&processed_event_params)
 770                .await?;
 771        }
 772    }
 773
 774    Ok(())
 775}
 776
 777async fn handle_customer_event(
 778    app: &Arc<AppState>,
 779    _stripe_client: &stripe::Client,
 780    event: stripe::Event,
 781) -> anyhow::Result<()> {
 782    let EventObject::Customer(customer) = event.data.object else {
 783        bail!("unexpected event payload for {}", event.id);
 784    };
 785
 786    log::info!("handling Stripe {} event: {}", event.type_, event.id);
 787
 788    let Some(email) = customer.email else {
 789        log::info!("Stripe customer has no email: skipping");
 790        return Ok(());
 791    };
 792
 793    let Some(user) = app.db.get_user_by_email(&email).await? else {
 794        log::info!("no user found for email: skipping");
 795        return Ok(());
 796    };
 797
 798    if let Some(existing_customer) = app
 799        .db
 800        .get_billing_customer_by_stripe_customer_id(&customer.id)
 801        .await?
 802    {
 803        app.db
 804            .update_billing_customer(
 805                existing_customer.id,
 806                &UpdateBillingCustomerParams {
 807                    // For now we just leave the information as-is, as it is not
 808                    // likely to change.
 809                    ..Default::default()
 810                },
 811            )
 812            .await?;
 813    } else {
 814        app.db
 815            .create_billing_customer(&CreateBillingCustomerParams {
 816                user_id: user.id,
 817                stripe_customer_id: customer.id.to_string(),
 818            })
 819            .await?;
 820    }
 821
 822    Ok(())
 823}
 824
 825async fn sync_subscription(
 826    app: &Arc<AppState>,
 827    stripe_client: &Arc<dyn StripeClient>,
 828    subscription: StripeSubscription,
 829) -> anyhow::Result<billing_customer::Model> {
 830    let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
 831        stripe_billing
 832            .determine_subscription_kind(&subscription)
 833            .await
 834    } else {
 835        None
 836    };
 837
 838    let billing_customer =
 839        find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
 840            .await?
 841            .context("billing customer not found")?;
 842
 843    if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
 844        if subscription.status == SubscriptionStatus::Trialing {
 845            let current_period_start =
 846                DateTime::from_timestamp(subscription.current_period_start, 0)
 847                    .context("No trial subscription period start")?;
 848
 849            app.db
 850                .update_billing_customer(
 851                    billing_customer.id,
 852                    &UpdateBillingCustomerParams {
 853                        trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
 854                        ..Default::default()
 855                    },
 856                )
 857                .await?;
 858        }
 859    }
 860
 861    let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
 862        && subscription
 863            .cancellation_details
 864            .as_ref()
 865            .and_then(|details| details.reason)
 866            .map_or(false, |reason| {
 867                reason == StripeCancellationDetailsReason::PaymentFailed
 868            });
 869
 870    if was_canceled_due_to_payment_failure {
 871        app.db
 872            .update_billing_customer(
 873                billing_customer.id,
 874                &UpdateBillingCustomerParams {
 875                    has_overdue_invoices: ActiveValue::set(true),
 876                    ..Default::default()
 877                },
 878            )
 879            .await?;
 880    }
 881
 882    if let Some(existing_subscription) = app
 883        .db
 884        .get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
 885        .await?
 886    {
 887        app.db
 888            .update_billing_subscription(
 889                existing_subscription.id,
 890                &UpdateBillingSubscriptionParams {
 891                    billing_customer_id: ActiveValue::set(billing_customer.id),
 892                    kind: ActiveValue::set(subscription_kind),
 893                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
 894                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
 895                    stripe_cancel_at: ActiveValue::set(
 896                        subscription
 897                            .cancel_at
 898                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 899                            .map(|time| time.naive_utc()),
 900                    ),
 901                    stripe_cancellation_reason: ActiveValue::set(
 902                        subscription
 903                            .cancellation_details
 904                            .and_then(|details| details.reason)
 905                            .map(|reason| reason.into()),
 906                    ),
 907                    stripe_current_period_start: ActiveValue::set(Some(
 908                        subscription.current_period_start,
 909                    )),
 910                    stripe_current_period_end: ActiveValue::set(Some(
 911                        subscription.current_period_end,
 912                    )),
 913                },
 914            )
 915            .await?;
 916    } else {
 917        if let Some(existing_subscription) = app
 918            .db
 919            .get_active_billing_subscription(billing_customer.user_id)
 920            .await?
 921        {
 922            if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
 923                && subscription_kind == Some(SubscriptionKind::ZedProTrial)
 924            {
 925                let stripe_subscription_id = StripeSubscriptionId(
 926                    existing_subscription.stripe_subscription_id.clone().into(),
 927                );
 928
 929                stripe_client
 930                    .cancel_subscription(&stripe_subscription_id)
 931                    .await?;
 932            } else {
 933                // If the user already has an active billing subscription, ignore the
 934                // event and return an `Ok` to signal that it was processed
 935                // successfully.
 936                //
 937                // There is the possibility that this could cause us to not create a
 938                // subscription in the following scenario:
 939                //
 940                //   1. User has an active subscription A
 941                //   2. User cancels subscription A
 942                //   3. User creates a new subscription B
 943                //   4. We process the new subscription B before the cancellation of subscription A
 944                //   5. User ends up with no subscriptions
 945                //
 946                // In theory this situation shouldn't arise as we try to process the events in the order they occur.
 947
 948                log::info!(
 949                    "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
 950                    user_id = billing_customer.user_id,
 951                    subscription_id = subscription.id
 952                );
 953                return Ok(billing_customer);
 954            }
 955        }
 956
 957        app.db
 958            .create_billing_subscription(&CreateBillingSubscriptionParams {
 959                billing_customer_id: billing_customer.id,
 960                kind: subscription_kind,
 961                stripe_subscription_id: subscription.id.to_string(),
 962                stripe_subscription_status: subscription.status.into(),
 963                stripe_cancellation_reason: subscription
 964                    .cancellation_details
 965                    .and_then(|details| details.reason)
 966                    .map(|reason| reason.into()),
 967                stripe_current_period_start: Some(subscription.current_period_start),
 968                stripe_current_period_end: Some(subscription.current_period_end),
 969            })
 970            .await?;
 971    }
 972
 973    if let Some(stripe_billing) = app.stripe_billing.as_ref() {
 974        if subscription.status == SubscriptionStatus::Canceled
 975            || subscription.status == SubscriptionStatus::Paused
 976        {
 977            let already_has_active_billing_subscription = app
 978                .db
 979                .has_active_billing_subscription(billing_customer.user_id)
 980                .await?;
 981            if !already_has_active_billing_subscription {
 982                let stripe_customer_id =
 983                    StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
 984
 985                stripe_billing
 986                    .subscribe_to_zed_free(stripe_customer_id)
 987                    .await?;
 988            }
 989        }
 990    }
 991
 992    Ok(billing_customer)
 993}
 994
 995async fn handle_customer_subscription_event(
 996    app: &Arc<AppState>,
 997    rpc_server: &Arc<Server>,
 998    stripe_client: &Arc<dyn StripeClient>,
 999    event: stripe::Event,
1000) -> anyhow::Result<()> {
1001    let EventObject::Subscription(subscription) = event.data.object else {
1002        bail!("unexpected event payload for {}", event.id);
1003    };
1004
1005    log::info!("handling Stripe {} event: {}", event.type_, event.id);
1006
1007    let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
1008
1009    // When the user's subscription changes, push down any changes to their plan.
1010    rpc_server
1011        .update_plan_for_user(billing_customer.user_id)
1012        .await
1013        .trace_err();
1014
1015    // When the user's subscription changes, we want to refresh their LLM tokens
1016    // to either grant/revoke access.
1017    rpc_server
1018        .refresh_llm_tokens_for_user(billing_customer.user_id)
1019        .await;
1020
1021    Ok(())
1022}
1023
1024impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1025    fn from(value: SubscriptionStatus) -> Self {
1026        match value {
1027            SubscriptionStatus::Incomplete => Self::Incomplete,
1028            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1029            SubscriptionStatus::Trialing => Self::Trialing,
1030            SubscriptionStatus::Active => Self::Active,
1031            SubscriptionStatus::PastDue => Self::PastDue,
1032            SubscriptionStatus::Canceled => Self::Canceled,
1033            SubscriptionStatus::Unpaid => Self::Unpaid,
1034            SubscriptionStatus::Paused => Self::Paused,
1035        }
1036    }
1037}
1038
1039impl From<CancellationDetailsReason> for StripeCancellationReason {
1040    fn from(value: CancellationDetailsReason) -> Self {
1041        match value {
1042            CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1043            CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1044            CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1045        }
1046    }
1047}
1048
1049/// Finds or creates a billing customer using the provided customer.
1050pub async fn find_or_create_billing_customer(
1051    app: &Arc<AppState>,
1052    stripe_client: &dyn StripeClient,
1053    customer_id: &StripeCustomerId,
1054) -> anyhow::Result<Option<billing_customer::Model>> {
1055    // If we already have a billing customer record associated with the Stripe customer,
1056    // there's nothing more we need to do.
1057    if let Some(billing_customer) = app
1058        .db
1059        .get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
1060        .await?
1061    {
1062        return Ok(Some(billing_customer));
1063    }
1064
1065    let customer = stripe_client.get_customer(customer_id).await?;
1066
1067    let Some(email) = customer.email else {
1068        return Ok(None);
1069    };
1070
1071    let Some(user) = app.db.get_user_by_email(&email).await? else {
1072        return Ok(None);
1073    };
1074
1075    let billing_customer = app
1076        .db
1077        .create_billing_customer(&CreateBillingCustomerParams {
1078            user_id: user.id,
1079            stripe_customer_id: customer.id.to_string(),
1080        })
1081        .await?;
1082
1083    Ok(Some(billing_customer))
1084}
1085
1086const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1087
1088pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1089    let Some(stripe_billing) = app.stripe_billing.clone() else {
1090        log::warn!("failed to retrieve Stripe billing object");
1091        return;
1092    };
1093    let Some(llm_db) = app.llm_db.clone() else {
1094        log::warn!("failed to retrieve LLM database");
1095        return;
1096    };
1097
1098    let executor = app.executor.clone();
1099    executor.spawn_detached({
1100        let executor = executor.clone();
1101        async move {
1102            loop {
1103                sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1104                    .await
1105                    .context("failed to sync LLM request usage to Stripe")
1106                    .trace_err();
1107                executor
1108                    .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1109                    .await;
1110            }
1111        }
1112    });
1113}
1114
1115async fn sync_model_request_usage_with_stripe(
1116    app: &Arc<AppState>,
1117    llm_db: &Arc<LlmDatabase>,
1118    stripe_billing: &Arc<StripeBilling>,
1119) -> anyhow::Result<()> {
1120    log::info!("Stripe usage sync: Starting");
1121    let started_at = Utc::now();
1122
1123    let staff_users = app.db.get_staff_users().await?;
1124    let staff_user_ids = staff_users
1125        .iter()
1126        .map(|user| user.id)
1127        .collect::<HashSet<UserId>>();
1128
1129    let usage_meters = llm_db
1130        .get_current_subscription_usage_meters(Utc::now())
1131        .await?;
1132    let mut usage_meters_by_user_id =
1133        HashMap::<UserId, Vec<subscription_usage_meter::Model>>::default();
1134    for (usage_meter, usage) in usage_meters {
1135        let meters = usage_meters_by_user_id.entry(usage.user_id).or_default();
1136        meters.push(usage_meter);
1137    }
1138
1139    log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions");
1140    let get_zed_pro_subscriptions_started_at = Utc::now();
1141    let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?;
1142    log::info!(
1143        "Stripe usage sync: Retrieved {} Zed Pro subscriptions in {}",
1144        billing_subscriptions.len(),
1145        Utc::now() - get_zed_pro_subscriptions_started_at
1146    );
1147
1148    let claude_sonnet_4 = stripe_billing
1149        .find_price_by_lookup_key("claude-sonnet-4-requests")
1150        .await?;
1151    let claude_sonnet_4_max = stripe_billing
1152        .find_price_by_lookup_key("claude-sonnet-4-requests-max")
1153        .await?;
1154    let claude_opus_4 = stripe_billing
1155        .find_price_by_lookup_key("claude-opus-4-requests")
1156        .await?;
1157    let claude_opus_4_max = stripe_billing
1158        .find_price_by_lookup_key("claude-opus-4-requests-max")
1159        .await?;
1160    let claude_3_5_sonnet = stripe_billing
1161        .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1162        .await?;
1163    let claude_3_7_sonnet = stripe_billing
1164        .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1165        .await?;
1166    let claude_3_7_sonnet_max = stripe_billing
1167        .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1168        .await?;
1169
1170    let model_mode_combinations = [
1171        ("claude-opus-4", CompletionMode::Max),
1172        ("claude-opus-4", CompletionMode::Normal),
1173        ("claude-sonnet-4", CompletionMode::Max),
1174        ("claude-sonnet-4", CompletionMode::Normal),
1175        ("claude-3-7-sonnet", CompletionMode::Max),
1176        ("claude-3-7-sonnet", CompletionMode::Normal),
1177        ("claude-3-5-sonnet", CompletionMode::Normal),
1178    ];
1179
1180    let billing_subscription_count = billing_subscriptions.len();
1181
1182    log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions");
1183
1184    for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions {
1185        maybe!(async {
1186            if staff_user_ids.contains(&user_id) {
1187                return anyhow::Ok(());
1188            }
1189
1190            let stripe_customer_id =
1191                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
1192            let stripe_subscription_id =
1193                StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
1194
1195            let usage_meters = usage_meters_by_user_id.get(&user_id);
1196
1197            for (model, mode) in &model_mode_combinations {
1198                let Ok(model) =
1199                    llm_db.model(LanguageModelProvider::Anthropic, model)
1200                else {
1201                    log::warn!("Failed to load model for user {user_id}: {model}");
1202                    continue;
1203                };
1204
1205                let (price, meter_event_name) = match model.name.as_str() {
1206                    "claude-opus-4" => match mode {
1207                        CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
1208                        CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
1209                    },
1210                    "claude-sonnet-4" => match mode {
1211                        CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
1212                        CompletionMode::Max => {
1213                            (&claude_sonnet_4_max, "claude_sonnet_4/requests/max")
1214                        }
1215                    },
1216                    "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1217                    "claude-3-7-sonnet" => match mode {
1218                        CompletionMode::Normal => {
1219                            (&claude_3_7_sonnet, "claude_3_7_sonnet/requests")
1220                        }
1221                        CompletionMode::Max => {
1222                            (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1223                        }
1224                    },
1225                    model_name => {
1226                        bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1227                    }
1228                };
1229
1230                let model_requests = usage_meters
1231                    .and_then(|usage_meters| {
1232                        usage_meters
1233                            .iter()
1234                            .find(|meter| meter.model_id == model.id && meter.mode == *mode)
1235                    })
1236                    .map(|usage_meter| usage_meter.requests)
1237                    .unwrap_or(0);
1238
1239                if model_requests > 0 {
1240                    stripe_billing
1241                        .subscribe_to_price(&stripe_subscription_id, price)
1242                        .await?;
1243                }
1244
1245                stripe_billing
1246                    .bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests)
1247                    .await
1248                    .with_context(|| {
1249                        format!(
1250                            "Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}",
1251                        )
1252                    })?;
1253            }
1254
1255            Ok(())
1256        })
1257        .await
1258        .log_err();
1259    }
1260
1261    log::info!(
1262        "Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}",
1263        Utc::now() - started_at
1264    );
1265
1266    Ok(())
1267}