billing.rs

   1use anyhow::{Context as _, bail};
   2use axum::{Extension, Json, Router, extract, routing::post};
   3use chrono::{DateTime, Utc};
   4use collections::{HashMap, HashSet};
   5use reqwest::StatusCode;
   6use sea_orm::ActiveValue;
   7use serde::{Deserialize, Serialize};
   8use std::{str::FromStr, sync::Arc, time::Duration};
   9use stripe::{
  10    BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
  11    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
  12    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
  13    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
  14    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
  15    CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
  16    PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
  17};
  18use util::{ResultExt, maybe};
  19use zed_llm_client::LanguageModelProvider;
  20
  21use crate::db::billing_subscription::{
  22    StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
  23};
  24use crate::llm::db::subscription_usage_meter::{self, CompletionMode};
  25use crate::rpc::{ResultExt as _, Server};
  26use crate::stripe_client::{
  27    StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
  28    StripeSubscriptionId, UpdateCustomerParams,
  29};
  30use crate::{AppState, Error, Result};
  31use crate::{db::UserId, llm::db::LlmDatabase};
  32use crate::{
  33    db::{
  34        BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
  35        CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
  36        UpdateBillingSubscriptionParams, billing_customer,
  37    },
  38    stripe_billing::StripeBilling,
  39};
  40
  41pub fn router() -> Router {
  42    Router::new()
  43        .route("/billing/subscriptions", post(create_billing_subscription))
  44        .route(
  45            "/billing/subscriptions/manage",
  46            post(manage_billing_subscription),
  47        )
  48        .route(
  49            "/billing/subscriptions/sync",
  50            post(sync_billing_subscription),
  51        )
  52}
  53
  54#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
  55#[serde(rename_all = "snake_case")]
  56enum ProductCode {
  57    ZedPro,
  58    ZedProTrial,
  59}
  60
  61#[derive(Debug, Deserialize)]
  62struct CreateBillingSubscriptionBody {
  63    github_user_id: i32,
  64    product: ProductCode,
  65}
  66
  67#[derive(Debug, Serialize)]
  68struct CreateBillingSubscriptionResponse {
  69    checkout_session_url: String,
  70}
  71
  72/// Initiates a Stripe Checkout session for creating a billing subscription.
  73async fn create_billing_subscription(
  74    Extension(app): Extension<Arc<AppState>>,
  75    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
  76) -> Result<Json<CreateBillingSubscriptionResponse>> {
  77    let user = app
  78        .db
  79        .get_user_by_github_user_id(body.github_user_id)
  80        .await?
  81        .context("user not found")?;
  82
  83    let Some(stripe_billing) = app.stripe_billing.clone() else {
  84        log::error!("failed to retrieve Stripe billing object");
  85        Err(Error::http(
  86            StatusCode::NOT_IMPLEMENTED,
  87            "not supported".into(),
  88        ))?
  89    };
  90
  91    if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
  92        let is_checkout_allowed = body.product == ProductCode::ZedProTrial
  93            && existing_subscription.kind == Some(SubscriptionKind::ZedFree);
  94
  95        if !is_checkout_allowed {
  96            return Err(Error::http(
  97                StatusCode::CONFLICT,
  98                "user already has an active subscription".into(),
  99            ));
 100        }
 101    }
 102
 103    let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
 104    if let Some(existing_billing_customer) = &existing_billing_customer {
 105        if existing_billing_customer.has_overdue_invoices {
 106            return Err(Error::http(
 107                StatusCode::PAYMENT_REQUIRED,
 108                "user has overdue invoices".into(),
 109            ));
 110        }
 111    }
 112
 113    let customer_id = if let Some(existing_customer) = &existing_billing_customer {
 114        let customer_id = StripeCustomerId(existing_customer.stripe_customer_id.clone().into());
 115        if let Some(email) = user.email_address.as_deref() {
 116            stripe_billing
 117                .client()
 118                .update_customer(&customer_id, UpdateCustomerParams { email: Some(email) })
 119                .await
 120                // Update of email address is best-effort - continue checkout even if it fails
 121                .context("error updating stripe customer email address")
 122                .log_err();
 123        }
 124        customer_id
 125    } else {
 126        stripe_billing
 127            .find_or_create_customer_by_email(user.email_address.as_deref())
 128            .await?
 129    };
 130
 131    let success_url = format!(
 132        "{}/account?checkout_complete=1",
 133        app.config.zed_dot_dev_url()
 134    );
 135
 136    let checkout_session_url = match body.product {
 137        ProductCode::ZedPro => {
 138            stripe_billing
 139                .checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
 140                .await?
 141        }
 142        ProductCode::ZedProTrial => {
 143            if let Some(existing_billing_customer) = &existing_billing_customer {
 144                if existing_billing_customer.trial_started_at.is_some() {
 145                    return Err(Error::http(
 146                        StatusCode::FORBIDDEN,
 147                        "user already used free trial".into(),
 148                    ));
 149                }
 150            }
 151
 152            let feature_flags = app.db.get_user_flags(user.id).await?;
 153
 154            stripe_billing
 155                .checkout_with_zed_pro_trial(
 156                    &customer_id,
 157                    &user.github_login,
 158                    feature_flags,
 159                    &success_url,
 160                )
 161                .await?
 162        }
 163    };
 164
 165    Ok(Json(CreateBillingSubscriptionResponse {
 166        checkout_session_url,
 167    }))
 168}
 169
 170#[derive(Debug, PartialEq, Deserialize)]
 171#[serde(rename_all = "snake_case")]
 172enum ManageSubscriptionIntent {
 173    /// The user intends to manage their subscription.
 174    ///
 175    /// This will open the Stripe billing portal without putting the user in a specific flow.
 176    ManageSubscription,
 177    /// The user intends to update their payment method.
 178    UpdatePaymentMethod,
 179    /// The user intends to upgrade to Zed Pro.
 180    UpgradeToPro,
 181    /// The user intends to cancel their subscription.
 182    Cancel,
 183    /// The user intends to stop the cancellation of their subscription.
 184    StopCancellation,
 185}
 186
 187#[derive(Debug, Deserialize)]
 188struct ManageBillingSubscriptionBody {
 189    github_user_id: i32,
 190    intent: ManageSubscriptionIntent,
 191    /// The ID of the subscription to manage.
 192    subscription_id: BillingSubscriptionId,
 193    redirect_to: Option<String>,
 194}
 195
 196#[derive(Debug, Serialize)]
 197struct ManageBillingSubscriptionResponse {
 198    billing_portal_session_url: Option<String>,
 199}
 200
 201/// Initiates a Stripe customer portal session for managing a billing subscription.
 202async fn manage_billing_subscription(
 203    Extension(app): Extension<Arc<AppState>>,
 204    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
 205) -> Result<Json<ManageBillingSubscriptionResponse>> {
 206    let user = app
 207        .db
 208        .get_user_by_github_user_id(body.github_user_id)
 209        .await?
 210        .context("user not found")?;
 211
 212    let Some(stripe_client) = app.real_stripe_client.clone() else {
 213        log::error!("failed to retrieve Stripe client");
 214        Err(Error::http(
 215            StatusCode::NOT_IMPLEMENTED,
 216            "not supported".into(),
 217        ))?
 218    };
 219
 220    let Some(stripe_billing) = app.stripe_billing.clone() else {
 221        log::error!("failed to retrieve Stripe billing object");
 222        Err(Error::http(
 223            StatusCode::NOT_IMPLEMENTED,
 224            "not supported".into(),
 225        ))?
 226    };
 227
 228    let customer = app
 229        .db
 230        .get_billing_customer_by_user_id(user.id)
 231        .await?
 232        .context("billing customer not found")?;
 233    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
 234        .context("failed to parse customer ID")?;
 235
 236    let subscription = app
 237        .db
 238        .get_billing_subscription_by_id(body.subscription_id)
 239        .await?
 240        .context("subscription not found")?;
 241    let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
 242        .context("failed to parse subscription ID")?;
 243
 244    if body.intent == ManageSubscriptionIntent::StopCancellation {
 245        let updated_stripe_subscription = Subscription::update(
 246            &stripe_client,
 247            &subscription_id,
 248            stripe::UpdateSubscription {
 249                cancel_at_period_end: Some(false),
 250                ..Default::default()
 251            },
 252        )
 253        .await?;
 254
 255        app.db
 256            .update_billing_subscription(
 257                subscription.id,
 258                &UpdateBillingSubscriptionParams {
 259                    stripe_cancel_at: ActiveValue::set(
 260                        updated_stripe_subscription
 261                            .cancel_at
 262                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 263                            .map(|time| time.naive_utc()),
 264                    ),
 265                    ..Default::default()
 266                },
 267            )
 268            .await?;
 269
 270        return Ok(Json(ManageBillingSubscriptionResponse {
 271            billing_portal_session_url: None,
 272        }));
 273    }
 274
 275    let flow = match body.intent {
 276        ManageSubscriptionIntent::ManageSubscription => None,
 277        ManageSubscriptionIntent::UpgradeToPro => {
 278            let zed_pro_price_id: stripe::PriceId =
 279                stripe_billing.zed_pro_price_id().await?.try_into()?;
 280            let zed_free_price_id: stripe::PriceId =
 281                stripe_billing.zed_free_price_id().await?.try_into()?;
 282
 283            let stripe_subscription =
 284                Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
 285
 286            let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
 287                && stripe_subscription.items.data.iter().any(|item| {
 288                    item.price
 289                        .as_ref()
 290                        .map_or(false, |price| price.id == zed_pro_price_id)
 291                });
 292            if is_on_zed_pro_trial {
 293                let payment_methods = PaymentMethod::list(
 294                    &stripe_client,
 295                    &stripe::ListPaymentMethods {
 296                        customer: Some(stripe_subscription.customer.id()),
 297                        ..Default::default()
 298                    },
 299                )
 300                .await?;
 301
 302                let has_payment_method = !payment_methods.data.is_empty();
 303                if !has_payment_method {
 304                    return Err(Error::http(
 305                        StatusCode::BAD_REQUEST,
 306                        "missing payment method".into(),
 307                    ));
 308                }
 309
 310                // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
 311                Subscription::update(
 312                    &stripe_client,
 313                    &stripe_subscription.id,
 314                    stripe::UpdateSubscription {
 315                        trial_end: Some(stripe::Scheduled::now()),
 316                        ..Default::default()
 317                    },
 318                )
 319                .await?;
 320
 321                return Ok(Json(ManageBillingSubscriptionResponse {
 322                    billing_portal_session_url: None,
 323                }));
 324            }
 325
 326            let subscription_item_to_update = stripe_subscription
 327                .items
 328                .data
 329                .iter()
 330                .find_map(|item| {
 331                    let price = item.price.as_ref()?;
 332
 333                    if price.id == zed_free_price_id {
 334                        Some(item.id.clone())
 335                    } else {
 336                        None
 337                    }
 338                })
 339                .context("No subscription item to update")?;
 340
 341            Some(CreateBillingPortalSessionFlowData {
 342                type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
 343                subscription_update_confirm: Some(
 344                    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
 345                        subscription: subscription.stripe_subscription_id,
 346                        items: vec![
 347                            CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
 348                                id: subscription_item_to_update.to_string(),
 349                                price: Some(zed_pro_price_id.to_string()),
 350                                quantity: Some(1),
 351                            },
 352                        ],
 353                        discounts: None,
 354                    },
 355                ),
 356                ..Default::default()
 357            })
 358        }
 359        ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
 360            type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
 361            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
 362                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
 363                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
 364                    return_url: format!(
 365                        "{}{path}",
 366                        app.config.zed_dot_dev_url(),
 367                        path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
 368                    ),
 369                }),
 370                ..Default::default()
 371            }),
 372            ..Default::default()
 373        }),
 374        ManageSubscriptionIntent::Cancel => {
 375            if subscription.kind == Some(SubscriptionKind::ZedFree) {
 376                return Err(Error::http(
 377                    StatusCode::BAD_REQUEST,
 378                    "free subscription cannot be canceled".into(),
 379                ));
 380            }
 381
 382            Some(CreateBillingPortalSessionFlowData {
 383                type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
 384                after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
 385                    type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
 386                    redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
 387                        return_url: format!("{}/account", app.config.zed_dot_dev_url()),
 388                    }),
 389                    ..Default::default()
 390                }),
 391                subscription_cancel: Some(
 392                    stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
 393                        subscription: subscription.stripe_subscription_id,
 394                        retention: None,
 395                    },
 396                ),
 397                ..Default::default()
 398            })
 399        }
 400        ManageSubscriptionIntent::StopCancellation => unreachable!(),
 401    };
 402
 403    let mut params = CreateBillingPortalSession::new(customer_id);
 404    params.flow_data = flow;
 405    let return_url = format!("{}/account", app.config.zed_dot_dev_url());
 406    params.return_url = Some(&return_url);
 407
 408    let session = BillingPortalSession::create(&stripe_client, params).await?;
 409
 410    Ok(Json(ManageBillingSubscriptionResponse {
 411        billing_portal_session_url: Some(session.url),
 412    }))
 413}
 414
 415#[derive(Debug, Deserialize)]
 416struct SyncBillingSubscriptionBody {
 417    github_user_id: i32,
 418}
 419
 420#[derive(Debug, Serialize)]
 421struct SyncBillingSubscriptionResponse {
 422    stripe_customer_id: String,
 423}
 424
 425async fn sync_billing_subscription(
 426    Extension(app): Extension<Arc<AppState>>,
 427    extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
 428) -> Result<Json<SyncBillingSubscriptionResponse>> {
 429    let Some(stripe_client) = app.stripe_client.clone() else {
 430        log::error!("failed to retrieve Stripe client");
 431        Err(Error::http(
 432            StatusCode::NOT_IMPLEMENTED,
 433            "not supported".into(),
 434        ))?
 435    };
 436
 437    let user = app
 438        .db
 439        .get_user_by_github_user_id(body.github_user_id)
 440        .await?
 441        .context("user not found")?;
 442
 443    let billing_customer = app
 444        .db
 445        .get_billing_customer_by_user_id(user.id)
 446        .await?
 447        .context("billing customer not found")?;
 448    let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
 449
 450    let subscriptions = stripe_client
 451        .list_subscriptions_for_customer(&stripe_customer_id)
 452        .await?;
 453
 454    for subscription in subscriptions {
 455        let subscription_id = subscription.id.clone();
 456
 457        sync_subscription(&app, &stripe_client, subscription)
 458            .await
 459            .with_context(|| {
 460                format!(
 461                    "failed to sync subscription {subscription_id} for user {}",
 462                    user.id,
 463                )
 464            })?;
 465    }
 466
 467    Ok(Json(SyncBillingSubscriptionResponse {
 468        stripe_customer_id: billing_customer.stripe_customer_id.clone(),
 469    }))
 470}
 471
 472/// The amount of time we wait in between each poll of Stripe events.
 473///
 474/// This value should strike a balance between:
 475///   1. Being short enough that we update quickly when something in Stripe changes
 476///   2. Being long enough that we don't eat into our rate limits.
 477///
 478/// As a point of reference, the Sequin folks say they have this at **500ms**:
 479///
 480/// > We poll the Stripe /events endpoint every 500ms per account
 481/// >
 482/// > — https://blog.sequinstream.com/events-not-webhooks/
 483const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
 484
 485/// The maximum number of events to return per page.
 486///
 487/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
 488///
 489/// > Limit can range between 1 and 100, and the default is 10.
 490const EVENTS_LIMIT_PER_PAGE: u64 = 100;
 491
 492/// The number of pages consisting entirely of already-processed events that we
 493/// will see before we stop retrieving events.
 494///
 495/// This is used to prevent over-fetching the Stripe events API for events we've
 496/// already seen and processed.
 497const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
 498
 499/// Polls the Stripe events API periodically to reconcile the records in our
 500/// database with the data in Stripe.
 501pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
 502    let Some(real_stripe_client) = app.real_stripe_client.clone() else {
 503        log::warn!("failed to retrieve Stripe client");
 504        return;
 505    };
 506    let Some(stripe_client) = app.stripe_client.clone() else {
 507        log::warn!("failed to retrieve Stripe client");
 508        return;
 509    };
 510
 511    let executor = app.executor.clone();
 512    executor.spawn_detached({
 513        let executor = executor.clone();
 514        async move {
 515            loop {
 516                poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
 517                    .await
 518                    .log_err();
 519
 520                executor.sleep(POLL_EVENTS_INTERVAL).await;
 521            }
 522        }
 523    });
 524}
 525
 526async fn poll_stripe_events(
 527    app: &Arc<AppState>,
 528    rpc_server: &Arc<Server>,
 529    stripe_client: &Arc<dyn StripeClient>,
 530    real_stripe_client: &stripe::Client,
 531) -> anyhow::Result<()> {
 532    fn event_type_to_string(event_type: EventType) -> String {
 533        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
 534        // so we need to unquote it.
 535        event_type.to_string().trim_matches('"').to_string()
 536    }
 537
 538    let event_types = [
 539        EventType::CustomerCreated,
 540        EventType::CustomerUpdated,
 541        EventType::CustomerSubscriptionCreated,
 542        EventType::CustomerSubscriptionUpdated,
 543        EventType::CustomerSubscriptionPaused,
 544        EventType::CustomerSubscriptionResumed,
 545        EventType::CustomerSubscriptionDeleted,
 546    ]
 547    .into_iter()
 548    .map(event_type_to_string)
 549    .collect::<Vec<_>>();
 550
 551    let mut pages_of_already_processed_events = 0;
 552    let mut unprocessed_events = Vec::new();
 553
 554    log::info!(
 555        "Stripe events: starting retrieval for {}",
 556        event_types.join(", ")
 557    );
 558    let mut params = ListEvents::new();
 559    params.types = Some(event_types.clone());
 560    params.limit = Some(EVENTS_LIMIT_PER_PAGE);
 561
 562    let mut event_pages = stripe::Event::list(&real_stripe_client, &params)
 563        .await?
 564        .paginate(params);
 565
 566    loop {
 567        let processed_event_ids = {
 568            let event_ids = event_pages
 569                .page
 570                .data
 571                .iter()
 572                .map(|event| event.id.as_str())
 573                .collect::<Vec<_>>();
 574            app.db
 575                .get_processed_stripe_events_by_event_ids(&event_ids)
 576                .await?
 577                .into_iter()
 578                .map(|event| event.stripe_event_id)
 579                .collect::<Vec<_>>()
 580        };
 581
 582        let mut processed_events_in_page = 0;
 583        let events_in_page = event_pages.page.data.len();
 584        for event in &event_pages.page.data {
 585            if processed_event_ids.contains(&event.id.to_string()) {
 586                processed_events_in_page += 1;
 587                log::debug!("Stripe events: already processed '{}', skipping", event.id);
 588            } else {
 589                unprocessed_events.push(event.clone());
 590            }
 591        }
 592
 593        if processed_events_in_page == events_in_page {
 594            pages_of_already_processed_events += 1;
 595        }
 596
 597        if event_pages.page.has_more {
 598            if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
 599            {
 600                log::info!(
 601                    "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
 602                );
 603                break;
 604            } else {
 605                log::info!("Stripe events: retrieving next page");
 606                event_pages = event_pages.next(&real_stripe_client).await?;
 607            }
 608        } else {
 609            break;
 610        }
 611    }
 612
 613    log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
 614
 615    // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
 616    unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
 617
 618    for event in unprocessed_events {
 619        let event_id = event.id.clone();
 620        let processed_event_params = CreateProcessedStripeEventParams {
 621            stripe_event_id: event.id.to_string(),
 622            stripe_event_type: event_type_to_string(event.type_),
 623            stripe_event_created_timestamp: event.created,
 624        };
 625
 626        // If the event has happened too far in the past, we don't want to
 627        // process it and risk overwriting other more-recent updates.
 628        //
 629        // 1 day was chosen arbitrarily. This could be made longer or shorter.
 630        let one_day = Duration::from_secs(24 * 60 * 60);
 631        let a_day_ago = Utc::now() - one_day;
 632        if a_day_ago.timestamp() > event.created {
 633            log::info!(
 634                "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
 635                event_id
 636            );
 637            app.db
 638                .create_processed_stripe_event(&processed_event_params)
 639                .await?;
 640
 641            continue;
 642        }
 643
 644        let process_result = match event.type_ {
 645            EventType::CustomerCreated | EventType::CustomerUpdated => {
 646                handle_customer_event(app, real_stripe_client, event).await
 647            }
 648            EventType::CustomerSubscriptionCreated
 649            | EventType::CustomerSubscriptionUpdated
 650            | EventType::CustomerSubscriptionPaused
 651            | EventType::CustomerSubscriptionResumed
 652            | EventType::CustomerSubscriptionDeleted => {
 653                handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
 654            }
 655            _ => Ok(()),
 656        };
 657
 658        if let Some(()) = process_result
 659            .with_context(|| format!("failed to process event {event_id} successfully"))
 660            .log_err()
 661        {
 662            app.db
 663                .create_processed_stripe_event(&processed_event_params)
 664                .await?;
 665        }
 666    }
 667
 668    Ok(())
 669}
 670
 671async fn handle_customer_event(
 672    app: &Arc<AppState>,
 673    _stripe_client: &stripe::Client,
 674    event: stripe::Event,
 675) -> anyhow::Result<()> {
 676    let EventObject::Customer(customer) = event.data.object else {
 677        bail!("unexpected event payload for {}", event.id);
 678    };
 679
 680    log::info!("handling Stripe {} event: {}", event.type_, event.id);
 681
 682    let Some(email) = customer.email else {
 683        log::info!("Stripe customer has no email: skipping");
 684        return Ok(());
 685    };
 686
 687    let Some(user) = app.db.get_user_by_email(&email).await? else {
 688        log::info!("no user found for email: skipping");
 689        return Ok(());
 690    };
 691
 692    if let Some(existing_customer) = app
 693        .db
 694        .get_billing_customer_by_stripe_customer_id(&customer.id)
 695        .await?
 696    {
 697        app.db
 698            .update_billing_customer(
 699                existing_customer.id,
 700                &UpdateBillingCustomerParams {
 701                    // For now we just leave the information as-is, as it is not
 702                    // likely to change.
 703                    ..Default::default()
 704                },
 705            )
 706            .await?;
 707    } else {
 708        app.db
 709            .create_billing_customer(&CreateBillingCustomerParams {
 710                user_id: user.id,
 711                stripe_customer_id: customer.id.to_string(),
 712            })
 713            .await?;
 714    }
 715
 716    Ok(())
 717}
 718
 719async fn sync_subscription(
 720    app: &Arc<AppState>,
 721    stripe_client: &Arc<dyn StripeClient>,
 722    subscription: StripeSubscription,
 723) -> anyhow::Result<billing_customer::Model> {
 724    let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
 725        stripe_billing
 726            .determine_subscription_kind(&subscription)
 727            .await
 728    } else {
 729        None
 730    };
 731
 732    let billing_customer =
 733        find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
 734            .await?
 735            .context("billing customer not found")?;
 736
 737    if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
 738        if subscription.status == SubscriptionStatus::Trialing {
 739            let current_period_start =
 740                DateTime::from_timestamp(subscription.current_period_start, 0)
 741                    .context("No trial subscription period start")?;
 742
 743            app.db
 744                .update_billing_customer(
 745                    billing_customer.id,
 746                    &UpdateBillingCustomerParams {
 747                        trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
 748                        ..Default::default()
 749                    },
 750                )
 751                .await?;
 752        }
 753    }
 754
 755    let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
 756        && subscription
 757            .cancellation_details
 758            .as_ref()
 759            .and_then(|details| details.reason)
 760            .map_or(false, |reason| {
 761                reason == StripeCancellationDetailsReason::PaymentFailed
 762            });
 763
 764    if was_canceled_due_to_payment_failure {
 765        app.db
 766            .update_billing_customer(
 767                billing_customer.id,
 768                &UpdateBillingCustomerParams {
 769                    has_overdue_invoices: ActiveValue::set(true),
 770                    ..Default::default()
 771                },
 772            )
 773            .await?;
 774    }
 775
 776    if let Some(existing_subscription) = app
 777        .db
 778        .get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
 779        .await?
 780    {
 781        app.db
 782            .update_billing_subscription(
 783                existing_subscription.id,
 784                &UpdateBillingSubscriptionParams {
 785                    billing_customer_id: ActiveValue::set(billing_customer.id),
 786                    kind: ActiveValue::set(subscription_kind),
 787                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
 788                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
 789                    stripe_cancel_at: ActiveValue::set(
 790                        subscription
 791                            .cancel_at
 792                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 793                            .map(|time| time.naive_utc()),
 794                    ),
 795                    stripe_cancellation_reason: ActiveValue::set(
 796                        subscription
 797                            .cancellation_details
 798                            .and_then(|details| details.reason)
 799                            .map(|reason| reason.into()),
 800                    ),
 801                    stripe_current_period_start: ActiveValue::set(Some(
 802                        subscription.current_period_start,
 803                    )),
 804                    stripe_current_period_end: ActiveValue::set(Some(
 805                        subscription.current_period_end,
 806                    )),
 807                },
 808            )
 809            .await?;
 810    } else {
 811        if let Some(existing_subscription) = app
 812            .db
 813            .get_active_billing_subscription(billing_customer.user_id)
 814            .await?
 815        {
 816            if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
 817                && subscription_kind == Some(SubscriptionKind::ZedProTrial)
 818            {
 819                let stripe_subscription_id = StripeSubscriptionId(
 820                    existing_subscription.stripe_subscription_id.clone().into(),
 821                );
 822
 823                stripe_client
 824                    .cancel_subscription(&stripe_subscription_id)
 825                    .await?;
 826            } else {
 827                // If the user already has an active billing subscription, ignore the
 828                // event and return an `Ok` to signal that it was processed
 829                // successfully.
 830                //
 831                // There is the possibility that this could cause us to not create a
 832                // subscription in the following scenario:
 833                //
 834                //   1. User has an active subscription A
 835                //   2. User cancels subscription A
 836                //   3. User creates a new subscription B
 837                //   4. We process the new subscription B before the cancellation of subscription A
 838                //   5. User ends up with no subscriptions
 839                //
 840                // In theory this situation shouldn't arise as we try to process the events in the order they occur.
 841
 842                log::info!(
 843                    "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
 844                    user_id = billing_customer.user_id,
 845                    subscription_id = subscription.id
 846                );
 847                return Ok(billing_customer);
 848            }
 849        }
 850
 851        app.db
 852            .create_billing_subscription(&CreateBillingSubscriptionParams {
 853                billing_customer_id: billing_customer.id,
 854                kind: subscription_kind,
 855                stripe_subscription_id: subscription.id.to_string(),
 856                stripe_subscription_status: subscription.status.into(),
 857                stripe_cancellation_reason: subscription
 858                    .cancellation_details
 859                    .and_then(|details| details.reason)
 860                    .map(|reason| reason.into()),
 861                stripe_current_period_start: Some(subscription.current_period_start),
 862                stripe_current_period_end: Some(subscription.current_period_end),
 863            })
 864            .await?;
 865    }
 866
 867    if let Some(stripe_billing) = app.stripe_billing.as_ref() {
 868        if subscription.status == SubscriptionStatus::Canceled
 869            || subscription.status == SubscriptionStatus::Paused
 870        {
 871            let already_has_active_billing_subscription = app
 872                .db
 873                .has_active_billing_subscription(billing_customer.user_id)
 874                .await?;
 875            if !already_has_active_billing_subscription {
 876                let stripe_customer_id =
 877                    StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
 878
 879                stripe_billing
 880                    .subscribe_to_zed_free(stripe_customer_id)
 881                    .await?;
 882            }
 883        }
 884    }
 885
 886    Ok(billing_customer)
 887}
 888
 889async fn handle_customer_subscription_event(
 890    app: &Arc<AppState>,
 891    rpc_server: &Arc<Server>,
 892    stripe_client: &Arc<dyn StripeClient>,
 893    event: stripe::Event,
 894) -> anyhow::Result<()> {
 895    let EventObject::Subscription(subscription) = event.data.object else {
 896        bail!("unexpected event payload for {}", event.id);
 897    };
 898
 899    log::info!("handling Stripe {} event: {}", event.type_, event.id);
 900
 901    let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
 902
 903    // When the user's subscription changes, push down any changes to their plan.
 904    rpc_server
 905        .update_plan_for_user(billing_customer.user_id)
 906        .await
 907        .trace_err();
 908
 909    // When the user's subscription changes, we want to refresh their LLM tokens
 910    // to either grant/revoke access.
 911    rpc_server
 912        .refresh_llm_tokens_for_user(billing_customer.user_id)
 913        .await;
 914
 915    Ok(())
 916}
 917
 918impl From<SubscriptionStatus> for StripeSubscriptionStatus {
 919    fn from(value: SubscriptionStatus) -> Self {
 920        match value {
 921            SubscriptionStatus::Incomplete => Self::Incomplete,
 922            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
 923            SubscriptionStatus::Trialing => Self::Trialing,
 924            SubscriptionStatus::Active => Self::Active,
 925            SubscriptionStatus::PastDue => Self::PastDue,
 926            SubscriptionStatus::Canceled => Self::Canceled,
 927            SubscriptionStatus::Unpaid => Self::Unpaid,
 928            SubscriptionStatus::Paused => Self::Paused,
 929        }
 930    }
 931}
 932
 933impl From<CancellationDetailsReason> for StripeCancellationReason {
 934    fn from(value: CancellationDetailsReason) -> Self {
 935        match value {
 936            CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
 937            CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
 938            CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
 939        }
 940    }
 941}
 942
 943/// Finds or creates a billing customer using the provided customer.
 944pub async fn find_or_create_billing_customer(
 945    app: &Arc<AppState>,
 946    stripe_client: &dyn StripeClient,
 947    customer_id: &StripeCustomerId,
 948) -> anyhow::Result<Option<billing_customer::Model>> {
 949    // If we already have a billing customer record associated with the Stripe customer,
 950    // there's nothing more we need to do.
 951    if let Some(billing_customer) = app
 952        .db
 953        .get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
 954        .await?
 955    {
 956        return Ok(Some(billing_customer));
 957    }
 958
 959    let customer = stripe_client.get_customer(customer_id).await?;
 960
 961    let Some(email) = customer.email else {
 962        return Ok(None);
 963    };
 964
 965    let Some(user) = app.db.get_user_by_email(&email).await? else {
 966        return Ok(None);
 967    };
 968
 969    let billing_customer = app
 970        .db
 971        .create_billing_customer(&CreateBillingCustomerParams {
 972            user_id: user.id,
 973            stripe_customer_id: customer.id.to_string(),
 974        })
 975        .await?;
 976
 977    Ok(Some(billing_customer))
 978}
 979
 980const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
 981
 982pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
 983    let Some(stripe_billing) = app.stripe_billing.clone() else {
 984        log::warn!("failed to retrieve Stripe billing object");
 985        return;
 986    };
 987    let Some(llm_db) = app.llm_db.clone() else {
 988        log::warn!("failed to retrieve LLM database");
 989        return;
 990    };
 991
 992    let executor = app.executor.clone();
 993    executor.spawn_detached({
 994        let executor = executor.clone();
 995        async move {
 996            loop {
 997                sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
 998                    .await
 999                    .context("failed to sync LLM request usage to Stripe")
1000                    .trace_err();
1001                executor
1002                    .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1003                    .await;
1004            }
1005        }
1006    });
1007}
1008
1009async fn sync_model_request_usage_with_stripe(
1010    app: &Arc<AppState>,
1011    llm_db: &Arc<LlmDatabase>,
1012    stripe_billing: &Arc<StripeBilling>,
1013) -> anyhow::Result<()> {
1014    log::info!("Stripe usage sync: Starting");
1015    let started_at = Utc::now();
1016
1017    let staff_users = app.db.get_staff_users().await?;
1018    let staff_user_ids = staff_users
1019        .iter()
1020        .map(|user| user.id)
1021        .collect::<HashSet<UserId>>();
1022
1023    let usage_meters = llm_db
1024        .get_current_subscription_usage_meters(Utc::now())
1025        .await?;
1026    let mut usage_meters_by_user_id =
1027        HashMap::<UserId, Vec<subscription_usage_meter::Model>>::default();
1028    for (usage_meter, usage) in usage_meters {
1029        let meters = usage_meters_by_user_id.entry(usage.user_id).or_default();
1030        meters.push(usage_meter);
1031    }
1032
1033    log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions");
1034    let get_zed_pro_subscriptions_started_at = Utc::now();
1035    let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?;
1036    log::info!(
1037        "Stripe usage sync: Retrieved {} Zed Pro subscriptions in {}",
1038        billing_subscriptions.len(),
1039        Utc::now() - get_zed_pro_subscriptions_started_at
1040    );
1041
1042    let claude_sonnet_4 = stripe_billing
1043        .find_price_by_lookup_key("claude-sonnet-4-requests")
1044        .await?;
1045    let claude_sonnet_4_max = stripe_billing
1046        .find_price_by_lookup_key("claude-sonnet-4-requests-max")
1047        .await?;
1048    let claude_opus_4 = stripe_billing
1049        .find_price_by_lookup_key("claude-opus-4-requests")
1050        .await?;
1051    let claude_opus_4_max = stripe_billing
1052        .find_price_by_lookup_key("claude-opus-4-requests-max")
1053        .await?;
1054    let claude_3_5_sonnet = stripe_billing
1055        .find_price_by_lookup_key("claude-3-5-sonnet-requests")
1056        .await?;
1057    let claude_3_7_sonnet = stripe_billing
1058        .find_price_by_lookup_key("claude-3-7-sonnet-requests")
1059        .await?;
1060    let claude_3_7_sonnet_max = stripe_billing
1061        .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
1062        .await?;
1063
1064    let model_mode_combinations = [
1065        ("claude-opus-4", CompletionMode::Max),
1066        ("claude-opus-4", CompletionMode::Normal),
1067        ("claude-sonnet-4", CompletionMode::Max),
1068        ("claude-sonnet-4", CompletionMode::Normal),
1069        ("claude-3-7-sonnet", CompletionMode::Max),
1070        ("claude-3-7-sonnet", CompletionMode::Normal),
1071        ("claude-3-5-sonnet", CompletionMode::Normal),
1072    ];
1073
1074    let billing_subscription_count = billing_subscriptions.len();
1075
1076    log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions");
1077
1078    for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions {
1079        maybe!(async {
1080            if staff_user_ids.contains(&user_id) {
1081                return anyhow::Ok(());
1082            }
1083
1084            let stripe_customer_id =
1085                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
1086            let stripe_subscription_id =
1087                StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
1088
1089            let usage_meters = usage_meters_by_user_id.get(&user_id);
1090
1091            for (model, mode) in &model_mode_combinations {
1092                let Ok(model) =
1093                    llm_db.model(LanguageModelProvider::Anthropic, model)
1094                else {
1095                    log::warn!("Failed to load model for user {user_id}: {model}");
1096                    continue;
1097                };
1098
1099                let (price, meter_event_name) = match model.name.as_str() {
1100                    "claude-opus-4" => match mode {
1101                        CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
1102                        CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
1103                    },
1104                    "claude-sonnet-4" => match mode {
1105                        CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
1106                        CompletionMode::Max => {
1107                            (&claude_sonnet_4_max, "claude_sonnet_4/requests/max")
1108                        }
1109                    },
1110                    "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
1111                    "claude-3-7-sonnet" => match mode {
1112                        CompletionMode::Normal => {
1113                            (&claude_3_7_sonnet, "claude_3_7_sonnet/requests")
1114                        }
1115                        CompletionMode::Max => {
1116                            (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
1117                        }
1118                    },
1119                    model_name => {
1120                        bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
1121                    }
1122                };
1123
1124                let model_requests = usage_meters
1125                    .and_then(|usage_meters| {
1126                        usage_meters
1127                            .iter()
1128                            .find(|meter| meter.model_id == model.id && meter.mode == *mode)
1129                    })
1130                    .map(|usage_meter| usage_meter.requests)
1131                    .unwrap_or(0);
1132
1133                if model_requests > 0 {
1134                    stripe_billing
1135                        .subscribe_to_price(&stripe_subscription_id, price)
1136                        .await?;
1137                }
1138
1139                stripe_billing
1140                    .bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests)
1141                    .await
1142                    .with_context(|| {
1143                        format!(
1144                            "Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}",
1145                        )
1146                    })?;
1147            }
1148
1149            Ok(())
1150        })
1151        .await
1152        .log_err();
1153    }
1154
1155    log::info!(
1156        "Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}",
1157        Utc::now() - started_at
1158    );
1159
1160    Ok(())
1161}