billing.rs

  1use std::str::FromStr;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Context};
  5use axum::{extract, routing::post, Extension, Json, Router};
  6use reqwest::StatusCode;
  7use serde::{Deserialize, Serialize};
  8use stripe::{
  9    BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
 10    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
 11    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
 12    CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
 13    CreateCustomer, Customer, CustomerId,
 14};
 15
 16use crate::db::BillingSubscriptionId;
 17use crate::{AppState, Error, Result};
 18
 19pub fn router() -> Router {
 20    Router::new()
 21        .route("/billing/subscriptions", post(create_billing_subscription))
 22        .route(
 23            "/billing/subscriptions/manage",
 24            post(manage_billing_subscription),
 25        )
 26}
 27
 28#[derive(Debug, Deserialize)]
 29struct CreateBillingSubscriptionBody {
 30    github_user_id: i32,
 31}
 32
 33#[derive(Debug, Serialize)]
 34struct CreateBillingSubscriptionResponse {
 35    checkout_session_url: String,
 36}
 37
 38/// Initiates a Stripe Checkout session for creating a billing subscription.
 39async fn create_billing_subscription(
 40    Extension(app): Extension<Arc<AppState>>,
 41    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
 42) -> Result<Json<CreateBillingSubscriptionResponse>> {
 43    let user = app
 44        .db
 45        .get_user_by_github_user_id(body.github_user_id)
 46        .await?
 47        .ok_or_else(|| anyhow!("user not found"))?;
 48
 49    let Some((stripe_client, stripe_price_id)) = app
 50        .stripe_client
 51        .clone()
 52        .zip(app.config.stripe_price_id.clone())
 53    else {
 54        log::error!("failed to retrieve Stripe client or price ID");
 55        Err(Error::Http(
 56            StatusCode::NOT_IMPLEMENTED,
 57            "not supported".into(),
 58        ))?
 59    };
 60
 61    let customer_id =
 62        if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
 63            CustomerId::from_str(&existing_customer.stripe_customer_id)
 64                .context("failed to parse customer ID")?
 65        } else {
 66            let customer = Customer::create(
 67                &stripe_client,
 68                CreateCustomer {
 69                    email: user.email_address.as_deref(),
 70                    ..Default::default()
 71                },
 72            )
 73            .await?;
 74
 75            customer.id
 76        };
 77
 78    let checkout_session = {
 79        let mut params = CreateCheckoutSession::new();
 80        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
 81        params.customer = Some(customer_id);
 82        params.client_reference_id = Some(user.github_login.as_str());
 83        params.line_items = Some(vec![CreateCheckoutSessionLineItems {
 84            price: Some(stripe_price_id.to_string()),
 85            quantity: Some(1),
 86            ..Default::default()
 87        }]);
 88        params.success_url = Some("https://zed.dev/billing/success");
 89
 90        CheckoutSession::create(&stripe_client, params).await?
 91    };
 92
 93    Ok(Json(CreateBillingSubscriptionResponse {
 94        checkout_session_url: checkout_session
 95            .url
 96            .ok_or_else(|| anyhow!("no checkout session URL"))?,
 97    }))
 98}
 99
100#[derive(Debug, Deserialize)]
101#[serde(rename_all = "snake_case")]
102enum ManageSubscriptionIntent {
103    /// The user intends to cancel their subscription.
104    Cancel,
105}
106
107#[derive(Debug, Deserialize)]
108struct ManageBillingSubscriptionBody {
109    github_user_id: i32,
110    intent: ManageSubscriptionIntent,
111    /// The ID of the subscription to manage.
112    ///
113    /// If not provided, we will try to use the active subscription (if there is only one).
114    subscription_id: Option<BillingSubscriptionId>,
115}
116
117#[derive(Debug, Serialize)]
118struct ManageBillingSubscriptionResponse {
119    billing_portal_session_url: String,
120}
121
122/// Initiates a Stripe customer portal session for managing a billing subscription.
123async fn manage_billing_subscription(
124    Extension(app): Extension<Arc<AppState>>,
125    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
126) -> Result<Json<ManageBillingSubscriptionResponse>> {
127    let user = app
128        .db
129        .get_user_by_github_user_id(body.github_user_id)
130        .await?
131        .ok_or_else(|| anyhow!("user not found"))?;
132
133    let Some(stripe_client) = app.stripe_client.clone() else {
134        log::error!("failed to retrieve Stripe client");
135        Err(Error::Http(
136            StatusCode::NOT_IMPLEMENTED,
137            "not supported".into(),
138        ))?
139    };
140
141    let customer = app
142        .db
143        .get_billing_customer_by_user_id(user.id)
144        .await?
145        .ok_or_else(|| anyhow!("billing customer not found"))?;
146    let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
147        .context("failed to parse customer ID")?;
148
149    let subscription = if let Some(subscription_id) = body.subscription_id {
150        app.db
151            .get_billing_subscription_by_id(subscription_id)
152            .await?
153            .ok_or_else(|| anyhow!("subscription not found"))?
154    } else {
155        // If no subscription ID was provided, try to find the only active subscription ID.
156        let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
157        if subscriptions.len() > 1 {
158            Err(anyhow!("user has multiple active subscriptions"))?;
159        }
160
161        subscriptions
162            .into_iter()
163            .next()
164            .ok_or_else(|| anyhow!("user has no active subscriptions"))?
165    };
166
167    let flow = match body.intent {
168        ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
169            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
170            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
171                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
172                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
173                    return_url: "https://zed.dev/billing".into(),
174                }),
175                ..Default::default()
176            }),
177            subscription_cancel: Some(
178                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
179                    subscription: subscription.stripe_subscription_id,
180                    retention: None,
181                },
182            ),
183            ..Default::default()
184        },
185    };
186
187    let mut params = CreateBillingPortalSession::new(customer_id);
188    params.flow_data = Some(flow);
189    params.return_url = Some("https://zed.dev/billing");
190
191    let session = BillingPortalSession::create(&stripe_client, params).await?;
192
193    Ok(Json(ManageBillingSubscriptionResponse {
194        billing_portal_session_url: session.url,
195    }))
196}