billing.rs

  1use std::str::FromStr;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Context};
  5use axum::{extract, routing::post, Extension, Json, Router};
  6use collections::HashSet;
  7use reqwest::StatusCode;
  8use serde::{Deserialize, Serialize};
  9use stripe::{
 10    BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
 11    CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
 12    CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
 13    CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
 14    CustomerId,
 15};
 16
 17use crate::db::BillingSubscriptionId;
 18use crate::{AppState, Error, Result};
 19
 20pub fn router() -> Router {
 21    Router::new()
 22        .route("/billing/subscriptions", post(create_billing_subscription))
 23        .route(
 24            "/billing/subscriptions/manage",
 25            post(manage_billing_subscription),
 26        )
 27}
 28
 29#[derive(Debug, Deserialize)]
 30struct CreateBillingSubscriptionBody {
 31    github_user_id: i32,
 32}
 33
 34#[derive(Debug, Serialize)]
 35struct CreateBillingSubscriptionResponse {
 36    checkout_session_url: String,
 37}
 38
 39/// Initiates a Stripe Checkout session for creating a billing subscription.
 40async fn create_billing_subscription(
 41    Extension(app): Extension<Arc<AppState>>,
 42    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
 43) -> Result<Json<CreateBillingSubscriptionResponse>> {
 44    let user = app
 45        .db
 46        .get_user_by_github_user_id(body.github_user_id)
 47        .await?
 48        .ok_or_else(|| anyhow!("user not found"))?;
 49
 50    let Some((stripe_client, stripe_price_id)) = app
 51        .stripe_client
 52        .clone()
 53        .zip(app.config.stripe_price_id.clone())
 54    else {
 55        log::error!("failed to retrieve Stripe client or price ID");
 56        Err(Error::Http(
 57            StatusCode::NOT_IMPLEMENTED,
 58            "not supported".into(),
 59        ))?
 60    };
 61
 62    let existing_customer_id = {
 63        let existing_subscriptions = app.db.get_billing_subscriptions(user.id).await?;
 64        let distinct_customer_ids = existing_subscriptions
 65            .iter()
 66            .map(|subscription| subscription.stripe_customer_id.as_str())
 67            .collect::<HashSet<_>>();
 68        // Sanity: Make sure we can determine a single Stripe customer ID for the user.
 69        if distinct_customer_ids.len() > 1 {
 70            Err(anyhow!("user has multiple existing customer IDs"))?;
 71        }
 72
 73        distinct_customer_ids
 74            .into_iter()
 75            .next()
 76            .map(|id| CustomerId::from_str(id).context("failed to parse customer ID"))
 77            .transpose()
 78    }?;
 79
 80    let checkout_session = {
 81        let mut params = CreateCheckoutSession::new();
 82        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
 83        params.customer = existing_customer_id;
 84        params.client_reference_id = Some(user.github_login.as_str());
 85        params.line_items = Some(vec![CreateCheckoutSessionLineItems {
 86            price: Some(stripe_price_id.to_string()),
 87            quantity: Some(1),
 88            ..Default::default()
 89        }]);
 90        params.success_url = Some("https://zed.dev/billing/success");
 91
 92        CheckoutSession::create(&stripe_client, params).await?
 93    };
 94
 95    Ok(Json(CreateBillingSubscriptionResponse {
 96        checkout_session_url: checkout_session
 97            .url
 98            .ok_or_else(|| anyhow!("no checkout session URL"))?,
 99    }))
100}
101
102#[derive(Debug, Deserialize)]
103#[serde(rename_all = "snake_case")]
104enum ManageSubscriptionIntent {
105    /// The user intends to cancel their subscription.
106    Cancel,
107}
108
109#[derive(Debug, Deserialize)]
110struct ManageBillingSubscriptionBody {
111    github_user_id: i32,
112    intent: ManageSubscriptionIntent,
113    /// The ID of the subscription to manage.
114    ///
115    /// If not provided, we will try to use the active subscription (if there is only one).
116    subscription_id: Option<BillingSubscriptionId>,
117}
118
119#[derive(Debug, Serialize)]
120struct ManageBillingSubscriptionResponse {
121    billing_portal_session_url: String,
122}
123
124/// Initiates a Stripe customer portal session for managing a billing subscription.
125async fn manage_billing_subscription(
126    Extension(app): Extension<Arc<AppState>>,
127    extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
128) -> Result<Json<ManageBillingSubscriptionResponse>> {
129    let user = app
130        .db
131        .get_user_by_github_user_id(body.github_user_id)
132        .await?
133        .ok_or_else(|| anyhow!("user not found"))?;
134
135    let Some(stripe_client) = app.stripe_client.clone() else {
136        log::error!("failed to retrieve Stripe client");
137        Err(Error::Http(
138            StatusCode::NOT_IMPLEMENTED,
139            "not supported".into(),
140        ))?
141    };
142
143    let subscription = if let Some(subscription_id) = body.subscription_id {
144        app.db
145            .get_billing_subscription_by_id(subscription_id)
146            .await?
147            .ok_or_else(|| anyhow!("subscription not found"))?
148    } else {
149        // If no subscription ID was provided, try to find the only active subscription ID.
150        let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
151        if subscriptions.len() > 1 {
152            Err(anyhow!("user has multiple active subscriptions"))?;
153        }
154
155        subscriptions
156            .into_iter()
157            .next()
158            .ok_or_else(|| anyhow!("user has no active subscriptions"))?
159    };
160
161    let customer_id = CustomerId::from_str(&subscription.stripe_customer_id)
162        .context("failed to parse customer ID")?;
163
164    let flow = match body.intent {
165        ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
166            type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
167            after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
168                type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
169                redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
170                    return_url: "https://zed.dev/billing".into(),
171                }),
172                ..Default::default()
173            }),
174            subscription_cancel: Some(
175                stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
176                    subscription: subscription.stripe_subscription_id,
177                    retention: None,
178                },
179            ),
180            ..Default::default()
181        },
182    };
183
184    let mut params = CreateBillingPortalSession::new(customer_id);
185    params.flow_data = Some(flow);
186    params.return_url = Some("https://zed.dev/billing");
187
188    let session = BillingPortalSession::create(&stripe_client, params).await?;
189
190    Ok(Json(ManageBillingSubscriptionResponse {
191        billing_portal_session_url: session.url,
192    }))
193}