billing.rs

 1use std::str::FromStr;
 2use std::sync::Arc;
 3
 4use anyhow::anyhow;
 5use axum::{extract, routing::post, Extension, Json, Router};
 6use collections::HashSet;
 7use reqwest::StatusCode;
 8use serde::{Deserialize, Serialize};
 9use stripe::{CheckoutSession, CreateCheckoutSession, CreateCheckoutSessionLineItems, CustomerId};
10
11use crate::{AppState, Error, Result};
12
13pub fn router() -> Router {
14    Router::new().route("/billing/subscriptions", post(create_billing_subscription))
15}
16
17#[derive(Debug, Deserialize)]
18struct CreateBillingSubscriptionBody {
19    github_user_id: i32,
20}
21
22#[derive(Debug, Serialize)]
23struct CreateBillingSubscriptionResponse {
24    checkout_session_url: String,
25}
26
27/// Initiates a Stripe Checkout session for creating a billing subscription.
28async fn create_billing_subscription(
29    Extension(app): Extension<Arc<AppState>>,
30    extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
31) -> Result<Json<CreateBillingSubscriptionResponse>> {
32    let user = app
33        .db
34        .get_user_by_github_user_id(body.github_user_id)
35        .await?
36        .ok_or_else(|| anyhow!("user not found"))?;
37
38    let Some((stripe_client, stripe_price_id)) = app
39        .stripe_client
40        .clone()
41        .zip(app.config.stripe_price_id.clone())
42    else {
43        log::error!("failed to retrieve Stripe client or price ID");
44        Err(Error::Http(
45            StatusCode::NOT_IMPLEMENTED,
46            "not supported".into(),
47        ))?
48    };
49
50    let existing_customer_id = {
51        let existing_subscriptions = app.db.get_billing_subscriptions(user.id).await?;
52        let distinct_customer_ids = existing_subscriptions
53            .iter()
54            .map(|subscription| subscription.stripe_customer_id.as_str())
55            .collect::<HashSet<_>>();
56        // Sanity: Make sure we can determine a single Stripe customer ID for the user.
57        if distinct_customer_ids.len() > 1 {
58            Err(anyhow!("user has multiple existing customer IDs"))?;
59        }
60
61        distinct_customer_ids
62            .into_iter()
63            .next()
64            .map(|id| CustomerId::from_str(id).map_err(|err| anyhow!(err)))
65            .transpose()
66    }?;
67
68    let checkout_session = {
69        let mut params = CreateCheckoutSession::new();
70        params.mode = Some(stripe::CheckoutSessionMode::Subscription);
71        params.customer = existing_customer_id;
72        params.client_reference_id = Some(user.github_login.as_str());
73        params.line_items = Some(vec![CreateCheckoutSessionLineItems {
74            price: Some(stripe_price_id.to_string()),
75            quantity: Some(1),
76            ..Default::default()
77        }]);
78        params.success_url = Some("https://zed.dev/billing/success");
79
80        CheckoutSession::create(&stripe_client, params).await?
81    };
82
83    Ok(Json(CreateBillingSubscriptionResponse {
84        checkout_session_url: checkout_session
85            .url
86            .ok_or_else(|| anyhow!("no checkout session URL"))?,
87    }))
88}