billing.rs

   1use anyhow::{Context, anyhow, bail};
   2use axum::{
   3    Extension, Json, Router,
   4    extract::{self, Query},
   5    routing::{get, post},
   6};
   7use chrono::{DateTime, SecondsFormat, Utc};
   8use collections::HashSet;
   9use reqwest::StatusCode;
  10use sea_orm::ActiveValue;
  11use serde::{Deserialize, Serialize};
  12use serde_json::json;
  13use std::{str::FromStr, sync::Arc, time::Duration};
  14use stripe::{
  15    BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
  16    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
  17    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
  18    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
  19    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
  20    CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
  21    EventType, Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId,
  22    SubscriptionStatus,
  23};
  24use util::{ResultExt, maybe};
  25
  26use crate::api::events::SnowflakeRow;
  27use crate::db::billing_subscription::{
  28    StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
  29};
  30use crate::llm::db::subscription_usage_meter::CompletionMode;
  31use crate::llm::{
  32    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT,
  33};
  34use crate::rpc::{ResultExt as _, Server};
  35use crate::{AppState, Cents, Error, Result};
  36use crate::{db::UserId, llm::db::LlmDatabase};
  37use crate::{
  38    db::{
  39        BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams,
  40        CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
  41        UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer,
  42    },
  43    stripe_billing::StripeBilling,
  44};
  45
  46pub fn router() -> Router {
  47    Router::new()
  48        .route(
  49            "/billing/preferences",
  50            get(get_billing_preferences).put(update_billing_preferences),
  51        )
  52        .route(
  53            "/billing/subscriptions",
  54            get(list_billing_subscriptions).post(create_billing_subscription),
  55        )
  56        .route(
  57            "/billing/subscriptions/manage",
  58            post(manage_billing_subscription),
  59        )
  60        .route(
  61            "/billing/subscriptions/migrate",
  62            post(migrate_to_new_billing),
  63        )
  64        .route(
  65            "/billing/subscriptions/sync",
  66            post(sync_billing_subscription),
  67        )
  68        .route("/billing/monthly_spend", get(get_monthly_spend))
  69        .route("/billing/usage", get(get_current_usage))
  70}
  71
/// Query parameters for `GET /billing/preferences`.
#[derive(Debug, Deserialize)]
struct GetBillingPreferencesParams {
    /// GitHub user ID of the user whose billing preferences are requested.
    github_user_id: i32,
}

/// Response body returned by both the `GET` and `PUT /billing/preferences` handlers.
#[derive(Debug, Serialize)]
struct BillingPreferencesResponse {
    /// When the user's trial started, as an RFC 3339 UTC timestamp with
    /// millisecond precision; `None` when no trial start is recorded for the
    /// user's billing customer (or the user has no billing customer).
    trial_started_at: Option<String>,
    /// Monthly cap on LLM usage spending, in cents.
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages are enabled for the user.
    model_request_overages_enabled: bool,
    /// Cap on model request overage spending, in cents.
    model_request_overages_spend_limit_in_cents: i32,
}
  84
  85async fn get_billing_preferences(
  86    Extension(app): Extension<Arc<AppState>>,
  87    Query(params): Query<GetBillingPreferencesParams>,
  88) -> Result<Json<BillingPreferencesResponse>> {
  89    let user = app
  90        .db
  91        .get_user_by_github_user_id(params.github_user_id)
  92        .await?
  93        .ok_or_else(|| anyhow!("user not found"))?;
  94
  95    let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
  96    let preferences = app.db.get_billing_preferences(user.id).await?;
  97
  98    Ok(Json(BillingPreferencesResponse {
  99        trial_started_at: billing_customer
 100            .and_then(|billing_customer| billing_customer.trial_started_at)
 101            .map(|trial_started_at| {
 102                trial_started_at
 103                    .and_utc()
 104                    .to_rfc3339_opts(SecondsFormat::Millis, true)
 105            }),
 106        max_monthly_llm_usage_spending_in_cents: preferences
 107            .as_ref()
 108            .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
 109                preferences.max_monthly_llm_usage_spending_in_cents
 110            }),
 111        model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
 112            preferences.model_request_overages_enabled
 113        }),
 114        model_request_overages_spend_limit_in_cents: preferences
 115            .as_ref()
 116            .map_or(0, |preferences| {
 117                preferences.model_request_overages_spend_limit_in_cents
 118            }),
 119    }))
 120}
 121
/// Request body for `PUT /billing/preferences`.
///
/// Fields marked `#[serde(default)]` fall back to their type's default when
/// omitted (`0` for the limits, `false` for the overage toggle).
#[derive(Debug, Deserialize)]
struct UpdateBillingPreferencesBody {
    /// GitHub user ID of the user whose preferences are being updated.
    github_user_id: i32,
    /// Requested monthly LLM spending cap, in cents; the handler clamps
    /// negative values to zero.
    #[serde(default)]
    max_monthly_llm_usage_spending_in_cents: i32,
    /// Whether model request overages should be enabled.
    #[serde(default)]
    model_request_overages_enabled: bool,
    /// Requested overage spending cap, in cents; the handler clamps negative
    /// values to zero.
    #[serde(default)]
    model_request_overages_spend_limit_in_cents: i32,
}
 132
 133async fn update_billing_preferences(
 134    Extension(app): Extension<Arc<AppState>>,
 135    Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
 136    extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
 137) -> Result<Json<BillingPreferencesResponse>> {
 138    let user = app
 139        .db
 140        .get_user_by_github_user_id(body.github_user_id)
 141        .await?
 142        .ok_or_else(|| anyhow!("user not found"))?;
 143
 144    let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
 145
 146    let max_monthly_llm_usage_spending_in_cents =
 147        body.max_monthly_llm_usage_spending_in_cents.max(0);
 148    let model_request_overages_spend_limit_in_cents =
 149        body.model_request_overages_spend_limit_in_cents.max(0);
 150
 151    let billing_preferences =
 152        if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
 153            app.db
 154                .update_billing_preferences(
 155                    user.id,
 156                    &UpdateBillingPreferencesParams {
 157                        max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
 158                            max_monthly_llm_usage_spending_in_cents,
 159                        ),
 160                        model_request_overages_enabled: ActiveValue::set(
 161                            body.model_request_overages_enabled,
 162                        ),
 163                        model_request_overages_spend_limit_in_cents: ActiveValue::set(
 164                            model_request_overages_spend_limit_in_cents,
 165                        ),
 166                    },
 167                )
 168                .await?
 169        } else {
 170            app.db
 171                .create_billing_preferences(
 172                    user.id,
 173                    &crate::db::CreateBillingPreferencesParams {
 174                        max_monthly_llm_usage_spending_in_cents,
 175                        model_request_overages_enabled: body.model_request_overages_enabled,
 176                        model_request_overages_spend_limit_in_cents,
 177                    },
 178                )
 179                .await?
 180        };
 181
 182    SnowflakeRow::new(
 183        "Billing Preferences Updated",
 184        Some(user.metrics_id),
 185        user.admin,
 186        None,
 187        json!({
 188            "user_id": user.id,
 189            "model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
 190            "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
 191            "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
 192        }),
 193    )
 194    .write(&app.kinesis_client, &app.config.kinesis_stream)
 195    .await
 196    .log_err();
 197
 198    rpc_server.refresh_llm_tokens_for_user(user.id).await;
 199
 200    Ok(Json(BillingPreferencesResponse {
 201        trial_started_at: billing_customer
 202            .and_then(|billing_customer| billing_customer.trial_started_at)
 203            .map(|trial_started_at| {
 204                trial_started_at
 205                    .and_utc()
 206                    .to_rfc3339_opts(SecondsFormat::Millis, true)
 207            }),
 208        max_monthly_llm_usage_spending_in_cents: billing_preferences
 209            .max_monthly_llm_usage_spending_in_cents,
 210        model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
 211        model_request_overages_spend_limit_in_cents: billing_preferences
 212            .model_request_overages_spend_limit_in_cents,
 213    }))
 214}
 215
/// Query parameters for `GET /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct ListBillingSubscriptionsParams {
    /// GitHub user ID of the user whose subscriptions are listed.
    github_user_id: i32,
}

/// JSON representation of a single billing subscription.
#[derive(Debug, Serialize)]
struct BillingSubscriptionJson {
    /// Our database ID for the subscription.
    id: BillingSubscriptionId,
    /// Human-readable plan name (e.g. "Zed Pro").
    name: String,
    /// The subscription's status as last recorded from Stripe.
    status: StripeSubscriptionStatus,
    /// RFC 3339 end of the trial; `list_billing_subscriptions` only sets this
    /// for Zed Pro trial subscriptions.
    trial_end_at: Option<String>,
    /// RFC 3339 time at which the subscription is scheduled to cancel, if any.
    cancel_at: Option<String>,
    /// Whether this subscription can be canceled.
    is_cancelable: bool,
}

/// Response body for `GET /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct ListBillingSubscriptionsResponse {
    /// All subscriptions recorded for the user.
    subscriptions: Vec<BillingSubscriptionJson>,
}
 236
 237async fn list_billing_subscriptions(
 238    Extension(app): Extension<Arc<AppState>>,
 239    Query(params): Query<ListBillingSubscriptionsParams>,
 240) -> Result<Json<ListBillingSubscriptionsResponse>> {
 241    let user = app
 242        .db
 243        .get_user_by_github_user_id(params.github_user_id)
 244        .await?
 245        .ok_or_else(|| anyhow!("user not found"))?;
 246
 247    let subscriptions = app.db.get_billing_subscriptions(user.id).await?;
 248
 249    Ok(Json(ListBillingSubscriptionsResponse {
 250        subscriptions: subscriptions
 251            .into_iter()
 252            .map(|subscription| BillingSubscriptionJson {
 253                id: subscription.id,
 254                name: match subscription.kind {
 255                    Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
 256                    Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
 257                    Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
 258                    None => "Zed LLM Usage".to_string(),
 259                },
 260                status: subscription.stripe_subscription_status,
 261                trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
 262                    maybe!({
 263                        let end_at = subscription.stripe_current_period_end?;
 264                        let end_at = DateTime::from_timestamp(end_at, 0)?;
 265
 266                        Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
 267                    })
 268                } else {
 269                    None
 270                },
 271                cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
 272                    cancel_at
 273                        .and_utc()
 274                        .to_rfc3339_opts(SecondsFormat::Millis, true)
 275                }),
 276                is_cancelable: subscription.stripe_subscription_status.is_cancelable()
 277                    && subscription.stripe_cancel_at.is_none(),
 278            })
 279            .collect(),
 280    }))
 281}
 282
/// Product a user can check out, deserialized from snake_case strings
/// (`"zed_pro"`, `"zed_pro_trial"`, `"zed_free"`).
#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ProductCode {
    ZedPro,
    ZedProTrial,
    ZedFree,
}

/// Request body for `POST /billing/subscriptions`.
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
    /// GitHub user ID of the user starting checkout.
    github_user_id: i32,
    /// The product to check out.
    product: ProductCode,
}

/// Response body for `POST /billing/subscriptions`.
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
    /// URL of the Stripe Checkout session the user should be sent to.
    checkout_session_url: String,
}
 301
 302/// Initiates a Stripe Checkout session for creating a billing subscription.
 303async fn create_billing_subscription(
 304    Extension(app): Extension<Arc<AppState>>,
 305    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
 306) -> Result<Json<CreateBillingSubscriptionResponse>> {
 307    let user = app
 308        .db
 309        .get_user_by_github_user_id(body.github_user_id)
 310        .await?
 311        .ok_or_else(|| anyhow!("user not found"))?;
 312
 313    let Some(stripe_client) = app.stripe_client.clone() else {
 314        log::error!("failed to retrieve Stripe client");
 315        Err(Error::http(
 316            StatusCode::NOT_IMPLEMENTED,
 317            "not supported".into(),
 318        ))?
 319    };
 320    let Some(stripe_billing) = app.stripe_billing.clone() else {
 321        log::error!("failed to retrieve Stripe billing object");
 322        Err(Error::http(
 323            StatusCode::NOT_IMPLEMENTED,
 324            "not supported".into(),
 325        ))?
 326    };
 327
 328    if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
 329        let is_checkout_allowed = body.product == ProductCode::ZedProTrial
 330            && existing_subscription.kind == Some(SubscriptionKind::ZedFree);
 331
 332        if !is_checkout_allowed {
 333            return Err(Error::http(
 334                StatusCode::CONFLICT,
 335                "user already has an active subscription".into(),
 336            ));
 337        }
 338    }
 339
 340    let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
 341    if let Some(existing_billing_customer) = &existing_billing_customer {
 342        if existing_billing_customer.has_overdue_invoices {
 343            return Err(Error::http(
 344                StatusCode::PAYMENT_REQUIRED,
 345                "user has overdue invoices".into(),
 346            ));
 347        }
 348    }
 349
 350    let customer_id = if let Some(existing_customer) = &existing_billing_customer {
 351        CustomerId::from_str(&existing_customer.stripe_customer_id)
 352            .context("failed to parse customer ID")?
 353    } else {
 354        let existing_customer = if let Some(email) = user.email_address.as_deref() {
 355            let customers = Customer::list(
 356                &stripe_client,
 357                &stripe::ListCustomers {
 358                    email: Some(email),
 359                    ..Default::default()
 360                },
 361            )
 362            .await?;
 363
 364            customers.data.first().cloned()
 365        } else {
 366            None
 367        };
 368
 369        if let Some(existing_customer) = existing_customer {
 370            existing_customer.id
 371        } else {
 372            let customer = Customer::create(
 373                &stripe_client,
 374                CreateCustomer {
 375                    email: user.email_address.as_deref(),
 376                    ..Default::default()
 377                },
 378            )
 379            .await?;
 380
 381            customer.id
 382        }
 383    };
 384
 385    let success_url = format!(
 386        "{}/account?checkout_complete=1",
 387        app.config.zed_dot_dev_url()
 388    );
 389
 390    let checkout_session_url = match body.product {
 391        ProductCode::ZedPro => {
 392            stripe_billing
 393                .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
 394                .await?
 395        }
 396        ProductCode::ZedProTrial => {
 397            if let Some(existing_billing_customer) = &existing_billing_customer {
 398                if existing_billing_customer.trial_started_at.is_some() {
 399                    return Err(Error::http(
 400                        StatusCode::FORBIDDEN,
 401                        "user already used free trial".into(),
 402                    ));
 403                }
 404            }
 405
 406            let feature_flags = app.db.get_user_flags(user.id).await?;
 407
 408            stripe_billing
 409                .checkout_with_zed_pro_trial(
 410                    customer_id,
 411                    &user.github_login,
 412                    feature_flags,
 413                    &success_url,
 414                )
 415                .await?
 416        }
 417        ProductCode::ZedFree => {
 418            stripe_billing
 419                .checkout_with_zed_free(customer_id, &user.github_login, &success_url)
 420                .await?
 421        }
 422    };
 423
 424    Ok(Json(CreateBillingSubscriptionResponse {
 425        checkout_session_url,
 426    }))
 427}
 428
/// What the user wants to do with a subscription, deserialized from
/// snake_case strings (e.g. `"upgrade_to_pro"`).
#[derive(Debug, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
    /// The user intends to manage their subscription.
    ///
    /// This will open the Stripe billing portal without putting the user in a specific flow.
    ManageSubscription,
    /// The user intends to update their payment method.
    UpdatePaymentMethod,
    /// The user intends to upgrade to Zed Pro.
    UpgradeToPro,
    /// The user intends to cancel their subscription.
    Cancel,
    /// The user intends to stop the cancellation of their subscription.
    StopCancellation,
}

/// Request body for `POST /billing/subscriptions/manage`.
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
    /// GitHub user ID of the user managing the subscription.
    github_user_id: i32,
    /// What the user wants to do.
    intent: ManageSubscriptionIntent,
    /// The ID of the subscription to manage.
    subscription_id: BillingSubscriptionId,
    /// Path appended to the zed.dev URL to return to after updating the
    /// payment method; defaults to `/account`. Only used by the
    /// `UpdatePaymentMethod` intent.
    redirect_to: Option<String>,
}

/// Response body for `POST /billing/subscriptions/manage`.
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
    /// URL of the Stripe billing portal session, or `None` when the request
    /// was completed directly without a portal session.
    billing_portal_session_url: Option<String>,
}
 459
 460/// Initiates a Stripe customer portal session for managing a billing subscription.
 461async fn manage_billing_subscription(
 462    Extension(app): Extension<Arc<AppState>>,
 463    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
 464) -> Result<Json<ManageBillingSubscriptionResponse>> {
 465    let user = app
 466        .db
 467        .get_user_by_github_user_id(body.github_user_id)
 468        .await?
 469        .ok_or_else(|| anyhow!("user not found"))?;
 470
 471    let Some(stripe_client) = app.stripe_client.clone() else {
 472        log::error!("failed to retrieve Stripe client");
 473        Err(Error::http(
 474            StatusCode::NOT_IMPLEMENTED,
 475            "not supported".into(),
 476        ))?
 477    };
 478
 479    let Some(stripe_billing) = app.stripe_billing.clone() else {
 480        log::error!("failed to retrieve Stripe billing object");
 481        Err(Error::http(
 482            StatusCode::NOT_IMPLEMENTED,
 483            "not supported".into(),
 484        ))?
 485    };
 486
 487    let customer = app
 488        .db
 489        .get_billing_customer_by_user_id(user.id)
 490        .await?
 491        .ok_or_else(|| anyhow!("billing customer not found"))?;
 492    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
 493        .context("failed to parse customer ID")?;
 494
 495    let subscription = app
 496        .db
 497        .get_billing_subscription_by_id(body.subscription_id)
 498        .await?
 499        .ok_or_else(|| anyhow!("subscription not found"))?;
 500    let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
 501        .context("failed to parse subscription ID")?;
 502
 503    if body.intent == ManageSubscriptionIntent::StopCancellation {
 504        let updated_stripe_subscription = Subscription::update(
 505            &stripe_client,
 506            &subscription_id,
 507            stripe::UpdateSubscription {
 508                cancel_at_period_end: Some(false),
 509                ..Default::default()
 510            },
 511        )
 512        .await?;
 513
 514        app.db
 515            .update_billing_subscription(
 516                subscription.id,
 517                &UpdateBillingSubscriptionParams {
 518                    stripe_cancel_at: ActiveValue::set(
 519                        updated_stripe_subscription
 520                            .cancel_at
 521                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
 522                            .map(|time| time.naive_utc()),
 523                    ),
 524                    ..Default::default()
 525                },
 526            )
 527            .await?;
 528
 529        return Ok(Json(ManageBillingSubscriptionResponse {
 530            billing_portal_session_url: None,
 531        }));
 532    }
 533
 534    let flow = match body.intent {
 535        ManageSubscriptionIntent::ManageSubscription => None,
 536        ManageSubscriptionIntent::UpgradeToPro => {
 537            let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
 538            let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
 539
 540            let stripe_subscription =
 541                Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
 542
 543            let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing
 544                && stripe_subscription.items.data.iter().any(|item| {
 545                    item.price
 546                        .as_ref()
 547                        .map_or(false, |price| price.id == zed_pro_price_id)
 548                });
 549            if is_on_zed_pro_trial {
 550                let payment_methods = PaymentMethod::list(
 551                    &stripe_client,
 552                    &stripe::ListPaymentMethods {
 553                        customer: Some(stripe_subscription.customer.id()),
 554                        ..Default::default()
 555                    },
 556                )
 557                .await?;
 558
 559                let has_payment_method = !payment_methods.data.is_empty();
 560                if !has_payment_method {
 561                    return Err(Error::http(
 562                        StatusCode::BAD_REQUEST,
 563                        "missing payment method".into(),
 564                    ));
 565                }
 566
 567                // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
 568                Subscription::update(
 569                    &stripe_client,
 570                    &stripe_subscription.id,
 571                    stripe::UpdateSubscription {
 572                        trial_end: Some(stripe::Scheduled::now()),
 573                        ..Default::default()
 574                    },
 575                )
 576                .await?;
 577
 578                return Ok(Json(ManageBillingSubscriptionResponse {
 579                    billing_portal_session_url: None,
 580                }));
 581            }
 582
 583            let subscription_item_to_update = stripe_subscription
 584                .items
 585                .data
 586                .iter()
 587                .find_map(|item| {
 588                    let price = item.price.as_ref()?;
 589
 590                    if price.id == zed_free_price_id {
 591                        Some(item.id.clone())
 592                    } else {
 593                        None
 594                    }
 595                })
 596                .ok_or_else(|| anyhow!("No subscription item to update"))?;
 597
 598            Some(CreateBillingPortalSessionFlowData {
 599                type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
 600                subscription_update_confirm: Some(
 601                    CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
 602                        subscription: subscription.stripe_subscription_id,
 603                        items: vec![
 604                            CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
 605                                id: subscription_item_to_update.to_string(),
 606                                price: Some(zed_pro_price_id.to_string()),
 607                                quantity: Some(1),
 608                            },
 609                        ],
 610                        discounts: None,
 611                    },
 612                ),
 613                ..Default::default()
 614            })
 615        }
 616        ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
 617            type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
 618            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
 619                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
 620                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
 621                    return_url: format!(
 622                        "{}{path}",
 623                        app.config.zed_dot_dev_url(),
 624                        path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
 625                    ),
 626                }),
 627                ..Default::default()
 628            }),
 629            ..Default::default()
 630        }),
 631        ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
 632            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
 633            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
 634                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
 635                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
 636                    return_url: format!("{}/account", app.config.zed_dot_dev_url()),
 637                }),
 638                ..Default::default()
 639            }),
 640            subscription_cancel: Some(
 641                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
 642                    subscription: subscription.stripe_subscription_id,
 643                    retention: None,
 644                },
 645            ),
 646            ..Default::default()
 647        }),
 648        ManageSubscriptionIntent::StopCancellation => unreachable!(),
 649    };
 650
 651    let mut params = CreateBillingPortalSession::new(customer_id);
 652    params.flow_data = flow;
 653    let return_url = format!("{}/account", app.config.zed_dot_dev_url());
 654    params.return_url = Some(&return_url);
 655
 656    let session = BillingPortalSession::create(&stripe_client, params).await?;
 657
 658    Ok(Json(ManageBillingSubscriptionResponse {
 659        billing_portal_session_url: Some(session.url),
 660    }))
 661}
 662
/// Request body for `POST /billing/subscriptions/migrate`.
#[derive(Debug, Deserialize)]
struct MigrateToNewBillingBody {
    /// GitHub user ID of the user being migrated.
    github_user_id: i32,
}

/// Response body for `POST /billing/subscriptions/migrate`.
#[derive(Debug, Serialize)]
struct MigrateToNewBillingResponse {
    /// The ID of the subscription that was canceled.
    ///
    /// This is the Stripe subscription ID as a string, or `None` if the user
    /// had no active subscription to cancel.
    canceled_subscription_id: Option<String>,
}
 673
 674async fn migrate_to_new_billing(
 675    Extension(app): Extension<Arc<AppState>>,
 676    extract::Json(body): extract::Json<MigrateToNewBillingBody>,
 677) -> Result<Json<MigrateToNewBillingResponse>> {
 678    let Some(stripe_client) = app.stripe_client.clone() else {
 679        log::error!("failed to retrieve Stripe client");
 680        Err(Error::http(
 681            StatusCode::NOT_IMPLEMENTED,
 682            "not supported".into(),
 683        ))?
 684    };
 685
 686    let user = app
 687        .db
 688        .get_user_by_github_user_id(body.github_user_id)
 689        .await?
 690        .ok_or_else(|| anyhow!("user not found"))?;
 691
 692    let old_billing_subscriptions_by_user = app
 693        .db
 694        .get_active_billing_subscriptions(HashSet::from_iter([user.id]))
 695        .await?;
 696
 697    let canceled_subscription_id = if let Some((_billing_customer, billing_subscription)) =
 698        old_billing_subscriptions_by_user.get(&user.id)
 699    {
 700        let stripe_subscription_id = billing_subscription
 701            .stripe_subscription_id
 702            .parse::<stripe::SubscriptionId>()
 703            .context("failed to parse Stripe subscription ID from database")?;
 704
 705        Subscription::cancel(
 706            &stripe_client,
 707            &stripe_subscription_id,
 708            stripe::CancelSubscription {
 709                invoice_now: Some(true),
 710                ..Default::default()
 711            },
 712        )
 713        .await?;
 714
 715        Some(stripe_subscription_id)
 716    } else {
 717        None
 718    };
 719
 720    let all_feature_flags = app.db.list_feature_flags().await?;
 721    let user_feature_flags = app.db.get_user_flags(user.id).await?;
 722
 723    for feature_flag in ["new-billing", "assistant2"] {
 724        let already_in_feature_flag = user_feature_flags.iter().any(|flag| flag == feature_flag);
 725        if already_in_feature_flag {
 726            continue;
 727        }
 728
 729        let feature_flag = all_feature_flags
 730            .iter()
 731            .find(|flag| flag.flag == feature_flag)
 732            .context("failed to find feature flag: {feature_flag:?}")?;
 733
 734        app.db.add_user_flag(user.id, feature_flag.id).await?;
 735    }
 736
 737    Ok(Json(MigrateToNewBillingResponse {
 738        canceled_subscription_id: canceled_subscription_id
 739            .map(|subscription_id| subscription_id.to_string()),
 740    }))
 741}
 742
/// Request body for `POST /billing/subscriptions/sync`.
#[derive(Debug, Deserialize)]
struct SyncBillingSubscriptionBody {
    /// GitHub user ID of the user whose subscriptions are synced.
    github_user_id: i32,
}

/// Response body for `POST /billing/subscriptions/sync`.
#[derive(Debug, Serialize)]
struct SyncBillingSubscriptionResponse {
    /// The Stripe customer ID whose subscriptions were synced, as stored in
    /// our database.
    stripe_customer_id: String,
}
 752
 753async fn sync_billing_subscription(
 754    Extension(app): Extension<Arc<AppState>>,
 755    extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
 756) -> Result<Json<SyncBillingSubscriptionResponse>> {
 757    let Some(stripe_client) = app.stripe_client.clone() else {
 758        log::error!("failed to retrieve Stripe client");
 759        Err(Error::http(
 760            StatusCode::NOT_IMPLEMENTED,
 761            "not supported".into(),
 762        ))?
 763    };
 764
 765    let user = app
 766        .db
 767        .get_user_by_github_user_id(body.github_user_id)
 768        .await?
 769        .ok_or_else(|| anyhow!("user not found"))?;
 770
 771    let billing_customer = app
 772        .db
 773        .get_billing_customer_by_user_id(user.id)
 774        .await?
 775        .ok_or_else(|| anyhow!("billing customer not found"))?;
 776    let stripe_customer_id = billing_customer
 777        .stripe_customer_id
 778        .parse::<stripe::CustomerId>()
 779        .context("failed to parse Stripe customer ID from database")?;
 780
 781    let subscriptions = Subscription::list(
 782        &stripe_client,
 783        &stripe::ListSubscriptions {
 784            customer: Some(stripe_customer_id),
 785            // Sync all non-canceled subscriptions.
 786            status: None,
 787            ..Default::default()
 788        },
 789    )
 790    .await?;
 791
 792    for subscription in subscriptions.data {
 793        let subscription_id = subscription.id.clone();
 794
 795        sync_subscription(&app, &stripe_client, subscription)
 796            .await
 797            .with_context(|| {
 798                format!(
 799                    "failed to sync subscription {subscription_id} for user {}",
 800                    user.id,
 801                )
 802            })?;
 803    }
 804
 805    Ok(Json(SyncBillingSubscriptionResponse {
 806        stripe_customer_id: billing_customer.stripe_customer_id.clone(),
 807    }))
 808}
 809
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between:
///   1. Being short enough that we update quickly when something in Stripe changes
///   2. Being long enough that we don't eat into our rate limits.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — <https://blog.sequinstream.com/events-not-webhooks/>
///
/// We currently poll less often than that: once every 5 seconds.
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
 822
/// The maximum number of events to return per page.
///
/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
/// From the Stripe API reference for the `limit` list parameter:
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;
 829
/// The number of pages consisting entirely of already-processed events that we
/// will see before we stop retrieving events.
///
/// The count accumulates across a single retrieval in `poll_stripe_events`,
/// so these pages do not need to be consecutive.
///
/// This is used to prevent over-fetching the Stripe events API for events we've
/// already seen and processed.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
 836
 837/// Polls the Stripe events API periodically to reconcile the records in our
 838/// database with the data in Stripe.
 839pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
 840    let Some(stripe_client) = app.stripe_client.clone() else {
 841        log::warn!("failed to retrieve Stripe client");
 842        return;
 843    };
 844
 845    let executor = app.executor.clone();
 846    executor.spawn_detached({
 847        let executor = executor.clone();
 848        async move {
 849            loop {
 850                poll_stripe_events(&app, &rpc_server, &stripe_client)
 851                    .await
 852                    .log_err();
 853
 854                executor.sleep(POLL_EVENTS_INTERVAL).await;
 855            }
 856        }
 857    });
 858}
 859
 860async fn poll_stripe_events(
 861    app: &Arc<AppState>,
 862    rpc_server: &Arc<Server>,
 863    stripe_client: &stripe::Client,
 864) -> anyhow::Result<()> {
 865    fn event_type_to_string(event_type: EventType) -> String {
 866        // Calling `to_string` on `stripe::EventType` members gives us a quoted string,
 867        // so we need to unquote it.
 868        event_type.to_string().trim_matches('"').to_string()
 869    }
 870
 871    let event_types = [
 872        EventType::CustomerCreated,
 873        EventType::CustomerUpdated,
 874        EventType::CustomerSubscriptionCreated,
 875        EventType::CustomerSubscriptionUpdated,
 876        EventType::CustomerSubscriptionPaused,
 877        EventType::CustomerSubscriptionResumed,
 878        EventType::CustomerSubscriptionDeleted,
 879    ]
 880    .into_iter()
 881    .map(event_type_to_string)
 882    .collect::<Vec<_>>();
 883
 884    let mut pages_of_already_processed_events = 0;
 885    let mut unprocessed_events = Vec::new();
 886
 887    log::info!(
 888        "Stripe events: starting retrieval for {}",
 889        event_types.join(", ")
 890    );
 891    let mut params = ListEvents::new();
 892    params.types = Some(event_types.clone());
 893    params.limit = Some(EVENTS_LIMIT_PER_PAGE);
 894
 895    let mut event_pages = stripe::Event::list(&stripe_client, &params)
 896        .await?
 897        .paginate(params);
 898
 899    loop {
 900        let processed_event_ids = {
 901            let event_ids = event_pages
 902                .page
 903                .data
 904                .iter()
 905                .map(|event| event.id.as_str())
 906                .collect::<Vec<_>>();
 907            app.db
 908                .get_processed_stripe_events_by_event_ids(&event_ids)
 909                .await?
 910                .into_iter()
 911                .map(|event| event.stripe_event_id)
 912                .collect::<Vec<_>>()
 913        };
 914
 915        let mut processed_events_in_page = 0;
 916        let events_in_page = event_pages.page.data.len();
 917        for event in &event_pages.page.data {
 918            if processed_event_ids.contains(&event.id.to_string()) {
 919                processed_events_in_page += 1;
 920                log::debug!("Stripe events: already processed '{}', skipping", event.id);
 921            } else {
 922                unprocessed_events.push(event.clone());
 923            }
 924        }
 925
 926        if processed_events_in_page == events_in_page {
 927            pages_of_already_processed_events += 1;
 928        }
 929
 930        if event_pages.page.has_more {
 931            if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
 932            {
 933                log::info!(
 934                    "Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
 935                );
 936                break;
 937            } else {
 938                log::info!("Stripe events: retrieving next page");
 939                event_pages = event_pages.next(&stripe_client).await?;
 940            }
 941        } else {
 942            break;
 943        }
 944    }
 945
 946    log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
 947
 948    // Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
 949    unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
 950
 951    for event in unprocessed_events {
 952        let event_id = event.id.clone();
 953        let processed_event_params = CreateProcessedStripeEventParams {
 954            stripe_event_id: event.id.to_string(),
 955            stripe_event_type: event_type_to_string(event.type_),
 956            stripe_event_created_timestamp: event.created,
 957        };
 958
 959        // If the event has happened too far in the past, we don't want to
 960        // process it and risk overwriting other more-recent updates.
 961        //
 962        // 1 day was chosen arbitrarily. This could be made longer or shorter.
 963        let one_day = Duration::from_secs(24 * 60 * 60);
 964        let a_day_ago = Utc::now() - one_day;
 965        if a_day_ago.timestamp() > event.created {
 966            log::info!(
 967                "Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
 968                event_id
 969            );
 970            app.db
 971                .create_processed_stripe_event(&processed_event_params)
 972                .await?;
 973
 974            continue;
 975        }
 976
 977        let process_result = match event.type_ {
 978            EventType::CustomerCreated | EventType::CustomerUpdated => {
 979                handle_customer_event(app, stripe_client, event).await
 980            }
 981            EventType::CustomerSubscriptionCreated
 982            | EventType::CustomerSubscriptionUpdated
 983            | EventType::CustomerSubscriptionPaused
 984            | EventType::CustomerSubscriptionResumed
 985            | EventType::CustomerSubscriptionDeleted => {
 986                handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
 987            }
 988            _ => Ok(()),
 989        };
 990
 991        if let Some(()) = process_result
 992            .with_context(|| format!("failed to process event {event_id} successfully"))
 993            .log_err()
 994        {
 995            app.db
 996                .create_processed_stripe_event(&processed_event_params)
 997                .await?;
 998        }
 999    }
1000
1001    Ok(())
1002}
1003
1004async fn handle_customer_event(
1005    app: &Arc<AppState>,
1006    _stripe_client: &stripe::Client,
1007    event: stripe::Event,
1008) -> anyhow::Result<()> {
1009    let EventObject::Customer(customer) = event.data.object else {
1010        bail!("unexpected event payload for {}", event.id);
1011    };
1012
1013    log::info!("handling Stripe {} event: {}", event.type_, event.id);
1014
1015    let Some(email) = customer.email else {
1016        log::info!("Stripe customer has no email: skipping");
1017        return Ok(());
1018    };
1019
1020    let Some(user) = app.db.get_user_by_email(&email).await? else {
1021        log::info!("no user found for email: skipping");
1022        return Ok(());
1023    };
1024
1025    if let Some(existing_customer) = app
1026        .db
1027        .get_billing_customer_by_stripe_customer_id(&customer.id)
1028        .await?
1029    {
1030        app.db
1031            .update_billing_customer(
1032                existing_customer.id,
1033                &UpdateBillingCustomerParams {
1034                    // For now we just leave the information as-is, as it is not
1035                    // likely to change.
1036                    ..Default::default()
1037                },
1038            )
1039            .await?;
1040    } else {
1041        app.db
1042            .create_billing_customer(&CreateBillingCustomerParams {
1043                user_id: user.id,
1044                stripe_customer_id: customer.id.to_string(),
1045            })
1046            .await?;
1047    }
1048
1049    Ok(())
1050}
1051
/// Reconciles a single Stripe subscription into our database and returns the
/// billing customer it belongs to.
///
/// In order, this:
///   1. Determines the subscription's [`SubscriptionKind`]. This only happens
///      when Stripe billing is configured; otherwise the kind is `None`.
///   2. Finds or creates the billing customer for the subscription's Stripe
///      customer, failing if none can be resolved.
///   3. Records `trial_started_at` when the subscription is a Zed Pro trial in
///      the `Trialing` status.
///   4. Sets `has_overdue_invoices` when the subscription is canceled with the
///      `PaymentFailed` cancellation reason.
///   5. Updates our billing subscription record if one exists for this Stripe
///      subscription ID, or creates one. Before creating, if the user already
///      has an active subscription, we either cancel it in Stripe (a Zed Free
///      subscription being replaced by a Zed Pro trial) or skip creation and
///      return early.
///   6. If Stripe billing is configured, the subscription is canceled or
///      paused, and the user has no active subscription left, subscribes the
///      customer to Zed Free.
///
/// # Errors
///
/// Returns an error if no billing customer can be found or created, if a
/// stored Stripe ID fails to parse, or if any database or Stripe call fails.
async fn sync_subscription(
    app: &Arc<AppState>,
    stripe_client: &stripe::Client,
    subscription: stripe::Subscription,
) -> anyhow::Result<billing_customer::Model> {
    // The kind can only be determined when Stripe billing is configured.
    let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
        stripe_billing
            .determine_subscription_kind(&subscription)
            .await
    } else {
        None
    };

    // This moves `subscription.customer` out of `subscription`; only its other
    // fields are used from here on.
    let billing_customer =
        find_or_create_billing_customer(app, stripe_client, subscription.customer)
            .await?
            .ok_or_else(|| anyhow!("billing customer not found"))?;

    // For a trialing Zed Pro trial, treat the start of the current period as
    // the trial start.
    if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
        if subscription.status == SubscriptionStatus::Trialing {
            let current_period_start =
                DateTime::from_timestamp(subscription.current_period_start, 0)
                    .ok_or_else(|| anyhow!("No trial subscription period start"))?;

            app.db
                .update_billing_customer(
                    billing_customer.id,
                    &UpdateBillingCustomerParams {
                        trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
                        ..Default::default()
                    },
                )
                .await?;
        }
    }

    let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
        && subscription
            .cancellation_details
            .as_ref()
            .and_then(|details| details.reason)
            .map_or(false, |reason| {
                reason == CancellationDetailsReason::PaymentFailed
            });

    // Mark the customer as having overdue invoices when the subscription was
    // canceled with the `PaymentFailed` reason.
    if was_canceled_due_to_payment_failure {
        app.db
            .update_billing_customer(
                billing_customer.id,
                &UpdateBillingCustomerParams {
                    has_overdue_invoices: ActiveValue::set(true),
                    ..Default::default()
                },
            )
            .await?;
    }

    // Known subscription: overwrite our record with the latest state from Stripe.
    if let Some(existing_subscription) = app
        .db
        .get_billing_subscription_by_stripe_subscription_id(&subscription.id)
        .await?
    {
        app.db
            .update_billing_subscription(
                existing_subscription.id,
                &UpdateBillingSubscriptionParams {
                    billing_customer_id: ActiveValue::set(billing_customer.id),
                    kind: ActiveValue::set(subscription_kind),
                    stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
                    stripe_subscription_status: ActiveValue::set(subscription.status.into()),
                    stripe_cancel_at: ActiveValue::set(
                        subscription
                            .cancel_at
                            .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
                            .map(|time| time.naive_utc()),
                    ),
                    stripe_cancellation_reason: ActiveValue::set(
                        subscription
                            .cancellation_details
                            .and_then(|details| details.reason)
                            .map(|reason| reason.into()),
                    ),
                    stripe_current_period_start: ActiveValue::set(Some(
                        subscription.current_period_start,
                    )),
                    stripe_current_period_end: ActiveValue::set(Some(
                        subscription.current_period_end,
                    )),
                },
            )
            .await?;
    } else {
        // Unknown subscription: only create a record if the user has no other
        // active subscription, with the one exception handled below.
        if let Some(existing_subscription) = app
            .db
            .get_active_billing_subscription(billing_customer.user_id)
            .await?
        {
            if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
                && subscription_kind == Some(SubscriptionKind::ZedProTrial)
            {
                // The user is starting a Zed Pro trial while on Zed Free: cancel
                // the Zed Free subscription in Stripe, then fall through to
                // create the trial's record. Our Zed Free record is not touched
                // here; presumably it is updated when Stripe's cancellation
                // event for it is processed.
                let stripe_subscription_id = existing_subscription
                    .stripe_subscription_id
                    .parse::<stripe::SubscriptionId>()
                    .context("failed to parse Stripe subscription ID from database")?;

                Subscription::cancel(
                    &stripe_client,
                    &stripe_subscription_id,
                    stripe::CancelSubscription {
                        invoice_now: None,
                        ..Default::default()
                    },
                )
                .await?;
            } else {
                // If the user already has an active billing subscription, ignore the
                // event and return an `Ok` to signal that it was processed
                // successfully.
                //
                // There is the possibility that this could cause us to not create a
                // subscription in the following scenario:
                //
                //   1. User has an active subscription A
                //   2. User cancels subscription A
                //   3. User creates a new subscription B
                //   4. We process the new subscription B before the cancellation of subscription A
                //   5. User ends up with no subscriptions
                //
                // In theory this situation shouldn't arise as we try to process the events in the order they occur.

                log::info!(
                    "user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
                    user_id = billing_customer.user_id,
                    subscription_id = subscription.id
                );
                return Ok(billing_customer);
            }
        }

        // NOTE(review): the update path above also writes `stripe_cancel_at`,
        // which is not set here on creation. Confirm whether
        // `CreateBillingSubscriptionParams` should carry it.
        app.db
            .create_billing_subscription(&CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: subscription_kind,
                stripe_subscription_id: subscription.id.to_string(),
                stripe_subscription_status: subscription.status.into(),
                stripe_cancellation_reason: subscription
                    .cancellation_details
                    .and_then(|details| details.reason)
                    .map(|reason| reason.into()),
                stripe_current_period_start: Some(subscription.current_period_start),
                stripe_current_period_end: Some(subscription.current_period_end),
            })
            .await?;
    }

    // When a subscription ends or is paused and the user has no other active
    // subscription, fall back to Zed Free so they are not left without one.
    if let Some(stripe_billing) = app.stripe_billing.as_ref() {
        if subscription.status == SubscriptionStatus::Canceled
            || subscription.status == SubscriptionStatus::Paused
        {
            let already_has_active_billing_subscription = app
                .db
                .has_active_billing_subscription(billing_customer.user_id)
                .await?;
            if !already_has_active_billing_subscription {
                let stripe_customer_id = billing_customer
                    .stripe_customer_id
                    .parse::<stripe::CustomerId>()
                    .context("failed to parse Stripe customer ID from database")?;

                stripe_billing
                    .subscribe_to_zed_free(stripe_customer_id)
                    .await?;
            }
        }
    }

    Ok(billing_customer)
}
1230
1231async fn handle_customer_subscription_event(
1232    app: &Arc<AppState>,
1233    rpc_server: &Arc<Server>,
1234    stripe_client: &stripe::Client,
1235    event: stripe::Event,
1236) -> anyhow::Result<()> {
1237    let EventObject::Subscription(subscription) = event.data.object else {
1238        bail!("unexpected event payload for {}", event.id);
1239    };
1240
1241    log::info!("handling Stripe {} event: {}", event.type_, event.id);
1242
1243    let billing_customer = sync_subscription(app, stripe_client, subscription).await?;
1244
1245    // When the user's subscription changes, push down any changes to their plan.
1246    rpc_server
1247        .update_plan_for_user(billing_customer.user_id)
1248        .await
1249        .trace_err();
1250
1251    // When the user's subscription changes, we want to refresh their LLM tokens
1252    // to either grant/revoke access.
1253    rpc_server
1254        .refresh_llm_tokens_for_user(billing_customer.user_id)
1255        .await;
1256
1257    Ok(())
1258}
1259
/// Query parameters for [`get_monthly_spend`].
#[derive(Debug, Deserialize)]
struct GetMonthlySpendParams {
    /// GitHub user ID of the user whose spending is being requested.
    github_user_id: i32,
}
1264
/// Response body for [`get_monthly_spend`]. All amounts are in cents.
#[derive(Debug, Serialize)]
struct GetMonthlySpendResponse {
    /// Portion of this month's spending covered by the free-tier allowance.
    monthly_free_tier_spend_in_cents: u32,
    /// The user's monthly free-tier allowance.
    monthly_free_tier_allowance_in_cents: u32,
    /// Portion of this month's spending above the free-tier allowance.
    monthly_spend_in_cents: u32,
}
1271
1272async fn get_monthly_spend(
1273    Extension(app): Extension<Arc<AppState>>,
1274    Query(params): Query<GetMonthlySpendParams>,
1275) -> Result<Json<GetMonthlySpendResponse>> {
1276    let user = app
1277        .db
1278        .get_user_by_github_user_id(params.github_user_id)
1279        .await?
1280        .ok_or_else(|| anyhow!("user not found"))?;
1281
1282    let Some(llm_db) = app.llm_db.clone() else {
1283        return Err(Error::http(
1284            StatusCode::NOT_IMPLEMENTED,
1285            "LLM database not available".into(),
1286        ));
1287    };
1288
1289    let free_tier = user
1290        .custom_llm_monthly_allowance_in_cents
1291        .map(|allowance| Cents(allowance as u32))
1292        .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT);
1293
1294    let spending_for_month = llm_db
1295        .get_user_spending_for_month(user.id, Utc::now())
1296        .await?;
1297
1298    let free_tier_spend = Cents::min(spending_for_month, free_tier);
1299    let monthly_spend = spending_for_month.saturating_sub(free_tier);
1300
1301    Ok(Json(GetMonthlySpendResponse {
1302        monthly_free_tier_spend_in_cents: free_tier_spend.0,
1303        monthly_free_tier_allowance_in_cents: free_tier.0,
1304        monthly_spend_in_cents: monthly_spend.0,
1305    }))
1306}
1307
/// Query parameters for [`get_current_usage`].
#[derive(Debug, Deserialize)]
struct GetCurrentUsageParams {
    /// GitHub user ID of the user whose usage is being requested.
    github_user_id: i32,
}
1312
/// Usage of a single metered resource, measured against its limit.
#[derive(Debug, Serialize)]
struct UsageCounts {
    /// Amount used so far in the current period.
    pub used: i32,
    /// Limit for the period, or `None` when the plan is unlimited for this resource.
    pub limit: Option<i32>,
    /// Amount left before reaching `limit`, or `None` when unlimited.
    pub remaining: Option<i32>,
}
1319
/// Number of requests made with one model in one completion mode.
#[derive(Debug, Serialize)]
struct ModelRequestUsage {
    /// Name of the model.
    pub model: String,
    /// Completion mode the requests were made in.
    pub mode: CompletionMode,
    /// Number of requests recorded on the corresponding usage meter.
    pub requests: i32,
}
1326
/// A user's usage within their current subscription period.
#[derive(Debug, Serialize)]
struct CurrentUsage {
    /// Model requests, measured against the plan's model-request limit.
    pub model_requests: UsageCounts,
    /// Breakdown of model requests by model and completion mode.
    pub model_request_usage: Vec<ModelRequestUsage>,
    /// Edit predictions, measured against the plan's edit-prediction limit.
    pub edit_predictions: UsageCounts,
}
1333
/// Response body for [`get_current_usage`].
///
/// The `Default` value (empty `plan`, no `current_usage`) is what
/// `get_current_usage` returns when the user has no active billing
/// subscription or no known subscription period.
#[derive(Debug, Default, Serialize)]
struct GetCurrentUsageResponse {
    /// Name of the user's plan.
    pub plan: String,
    /// Usage in the current period, when a period is known.
    pub current_usage: Option<CurrentUsage>,
}
1339
1340async fn get_current_usage(
1341    Extension(app): Extension<Arc<AppState>>,
1342    Query(params): Query<GetCurrentUsageParams>,
1343) -> Result<Json<GetCurrentUsageResponse>> {
1344    let user = app
1345        .db
1346        .get_user_by_github_user_id(params.github_user_id)
1347        .await?
1348        .ok_or_else(|| anyhow!("user not found"))?;
1349
1350    let feature_flags = app.db.get_user_flags(user.id).await?;
1351    let has_extended_trial = feature_flags
1352        .iter()
1353        .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
1354
1355    let Some(llm_db) = app.llm_db.clone() else {
1356        return Err(Error::http(
1357            StatusCode::NOT_IMPLEMENTED,
1358            "LLM database not available".into(),
1359        ));
1360    };
1361
1362    let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
1363        return Ok(Json(GetCurrentUsageResponse::default()));
1364    };
1365
1366    let subscription_period = maybe!({
1367        let period_start_at = subscription.current_period_start_at()?;
1368        let period_end_at = subscription.current_period_end_at()?;
1369
1370        Some((period_start_at, period_end_at))
1371    });
1372
1373    let Some((period_start_at, period_end_at)) = subscription_period else {
1374        return Ok(Json(GetCurrentUsageResponse::default()));
1375    };
1376
1377    let usage = llm_db
1378        .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
1379        .await?;
1380
1381    let plan = usage
1382        .as_ref()
1383        .map(|usage| usage.plan.into())
1384        .unwrap_or_else(|| {
1385            subscription
1386                .kind
1387                .map(Into::into)
1388                .unwrap_or(zed_llm_client::Plan::ZedFree)
1389        });
1390
1391    let model_requests_limit = match plan.model_requests_limit() {
1392        zed_llm_client::UsageLimit::Limited(limit) => {
1393            let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial {
1394                1_000
1395            } else {
1396                limit
1397            };
1398
1399            Some(limit)
1400        }
1401        zed_llm_client::UsageLimit::Unlimited => None,
1402    };
1403
1404    let edit_predictions_limit = match plan.edit_predictions_limit() {
1405        zed_llm_client::UsageLimit::Limited(limit) => Some(limit),
1406        zed_llm_client::UsageLimit::Unlimited => None,
1407    };
1408
1409    let Some(usage) = usage else {
1410        return Ok(Json(GetCurrentUsageResponse {
1411            plan: plan.as_str().to_string(),
1412            current_usage: Some(CurrentUsage {
1413                model_requests: UsageCounts {
1414                    used: 0,
1415                    limit: model_requests_limit,
1416                    remaining: model_requests_limit,
1417                },
1418                model_request_usage: Vec::new(),
1419                edit_predictions: UsageCounts {
1420                    used: 0,
1421                    limit: edit_predictions_limit,
1422                    remaining: edit_predictions_limit,
1423                },
1424            }),
1425        }));
1426    };
1427
1428    let subscription_usage_meters = llm_db
1429        .get_current_subscription_usage_meters_for_user(user.id, Utc::now())
1430        .await?;
1431
1432    let model_request_usage = subscription_usage_meters
1433        .into_iter()
1434        .filter_map(|(usage_meter, _usage)| {
1435            let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
1436
1437            Some(ModelRequestUsage {
1438                model: model.name.clone(),
1439                mode: usage_meter.mode,
1440                requests: usage_meter.requests,
1441            })
1442        })
1443        .collect::<Vec<_>>();
1444
1445    Ok(Json(GetCurrentUsageResponse {
1446        plan: plan.as_str().to_string(),
1447        current_usage: Some(CurrentUsage {
1448            model_requests: UsageCounts {
1449                used: usage.model_requests,
1450                limit: model_requests_limit,
1451                remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
1452            },
1453            model_request_usage,
1454            edit_predictions: UsageCounts {
1455                used: usage.edit_predictions,
1456                limit: edit_predictions_limit,
1457                remaining: edit_predictions_limit
1458                    .map(|limit| (limit - usage.edit_predictions).max(0)),
1459            },
1460        }),
1461    }))
1462}
1463
1464impl From<SubscriptionStatus> for StripeSubscriptionStatus {
1465    fn from(value: SubscriptionStatus) -> Self {
1466        match value {
1467            SubscriptionStatus::Incomplete => Self::Incomplete,
1468            SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
1469            SubscriptionStatus::Trialing => Self::Trialing,
1470            SubscriptionStatus::Active => Self::Active,
1471            SubscriptionStatus::PastDue => Self::PastDue,
1472            SubscriptionStatus::Canceled => Self::Canceled,
1473            SubscriptionStatus::Unpaid => Self::Unpaid,
1474            SubscriptionStatus::Paused => Self::Paused,
1475        }
1476    }
1477}
1478
1479impl From<CancellationDetailsReason> for StripeCancellationReason {
1480    fn from(value: CancellationDetailsReason) -> Self {
1481        match value {
1482            CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
1483            CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
1484            CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
1485        }
1486    }
1487}
1488
1489/// Finds or creates a billing customer using the provided customer.
1490async fn find_or_create_billing_customer(
1491    app: &Arc<AppState>,
1492    stripe_client: &stripe::Client,
1493    customer_or_id: Expandable<Customer>,
1494) -> anyhow::Result<Option<billing_customer::Model>> {
1495    let customer_id = match &customer_or_id {
1496        Expandable::Id(id) => id,
1497        Expandable::Object(customer) => customer.id.as_ref(),
1498    };
1499
1500    // If we already have a billing customer record associated with the Stripe customer,
1501    // there's nothing more we need to do.
1502    if let Some(billing_customer) = app
1503        .db
1504        .get_billing_customer_by_stripe_customer_id(customer_id)
1505        .await?
1506    {
1507        return Ok(Some(billing_customer));
1508    }
1509
1510    // If all we have is a customer ID, resolve it to a full customer record by
1511    // hitting the Stripe API.
1512    let customer = match customer_or_id {
1513        Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
1514        Expandable::Object(customer) => *customer,
1515    };
1516
1517    let Some(email) = customer.email else {
1518        return Ok(None);
1519    };
1520
1521    let Some(user) = app.db.get_user_by_email(&email).await? else {
1522        return Ok(None);
1523    };
1524
1525    let billing_customer = app
1526        .db
1527        .create_billing_customer(&CreateBillingCustomerParams {
1528            user_id: user.id,
1529            stripe_customer_id: customer.id.to_string(),
1530        })
1531        .await?;
1532
1533    Ok(Some(billing_customer))
1534}
1535
/// How long to wait between runs of `sync_model_request_usage_with_stripe`.
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
1537
1538pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
1539    let Some(stripe_billing) = app.stripe_billing.clone() else {
1540        log::warn!("failed to retrieve Stripe billing object");
1541        return;
1542    };
1543    let Some(llm_db) = app.llm_db.clone() else {
1544        log::warn!("failed to retrieve LLM database");
1545        return;
1546    };
1547
1548    let executor = app.executor.clone();
1549    executor.spawn_detached({
1550        let executor = executor.clone();
1551        async move {
1552            loop {
1553                sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
1554                    .await
1555                    .context("failed to sync LLM request usage to Stripe")
1556                    .trace_err();
1557                executor
1558                    .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
1559                    .await;
1560            }
1561        }
1562    });
1563}
1564
/// Reports current model request usage for non-staff users to Stripe.
///
/// Takes every current subscription usage meter that does not belong to a
/// staff user and looks up the user's entry among the billing subscriptions
/// returned by `get_active_zed_pro_billing_subscriptions`. It then subscribes
/// that Stripe subscription to the price for the meter's model/mode and
/// reports the meter's request count under the matching meter event name.
///
/// Only `claude-3-5-sonnet` and `claude-3-7-sonnet` (normal and max mode) are
/// supported. Each meter is handled independently. A failure for one meter is
/// logged and does not stop the rest: a missing subscription entry, an
/// unsupported model, an unparsable stored ID, or a Stripe error.
///
/// # Errors
///
/// Returns an error only if one of the up-front lookups fails: staff users,
/// usage meters, billing subscriptions, or resolving a Stripe price by lookup
/// key.
async fn sync_model_request_usage_with_stripe(
    app: &Arc<AppState>,
    llm_db: &Arc<LlmDatabase>,
    stripe_billing: &Arc<StripeBilling>,
) -> anyhow::Result<()> {
    // Usage by staff users is excluded from syncing.
    let staff_users = app.db.get_staff_users().await?;
    let staff_user_ids = staff_users
        .iter()
        .map(|user| user.id)
        .collect::<HashSet<UserId>>();

    let usage_meters = llm_db
        .get_current_subscription_usage_meters(Utc::now())
        .await?;
    let usage_meters = usage_meters
        .into_iter()
        .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
        .collect::<Vec<_>>();
    let user_ids = usage_meters
        .iter()
        .map(|(_, usage)| usage.user_id)
        .collect::<HashSet<UserId>>();
    let billing_subscriptions = app
        .db
        .get_active_zed_pro_billing_subscriptions(user_ids)
        .await?;

    // Resolve each Stripe price once up front rather than once per meter.
    let claude_3_5_sonnet = stripe_billing
        .find_price_by_lookup_key("claude-3-5-sonnet-requests")
        .await?;
    let claude_3_7_sonnet = stripe_billing
        .find_price_by_lookup_key("claude-3-7-sonnet-requests")
        .await?;
    let claude_3_7_sonnet_max = stripe_billing
        .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
        .await?;

    for (usage_meter, usage) in usage_meters {
        // Each meter is synced on its own; errors are logged and the loop moves on.
        maybe!(async {
            // NOTE(review): this fires when the user has no entry among the
            // active Zed Pro billing subscriptions, which is narrower than the
            // message's "not a Stripe customer" suggests.
            let Some((billing_customer, billing_subscription)) =
                billing_subscriptions.get(&usage.user_id)
            else {
                bail!(
                    "Attempted to sync usage meter for user who is not a Stripe customer: {}",
                    usage.user_id
                );
            };

            let stripe_customer_id = billing_customer
                .stripe_customer_id
                .parse::<stripe::CustomerId>()
                .context("failed to parse Stripe customer ID from database")?;
            let stripe_subscription_id = billing_subscription
                .stripe_subscription_id
                .parse::<stripe::SubscriptionId>()
                .context("failed to parse Stripe subscription ID from database")?;

            let model = llm_db.model_by_id(usage_meter.model_id)?;

            // Pick the Stripe price and meter event name for this model/mode.
            let (price, meter_event_name) = match model.name.as_str() {
                "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
                "claude-3-7-sonnet" => match usage_meter.mode {
                    CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
                    CompletionMode::Max => {
                        (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
                    }
                },
                model_name => {
                    bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
                }
            };

            // Subscribe to the price before reporting usage against it.
            stripe_billing
                .subscribe_to_price(&stripe_subscription_id, price)
                .await?;
            // NOTE(review): `usage_meter.requests` is sent on every run.
            // Whether Stripe treats it as a running total or an increment
            // depends on the meter's configuration; confirm this cannot
            // double-bill.
            stripe_billing
                .bill_model_request_usage(
                    &stripe_customer_id,
                    meter_event_name,
                    usage_meter.requests,
                )
                .await?;

            Ok(())
        })
        .await
        .log_err();
    }

    Ok(())
}