1use std::str::FromStr;
2use std::sync::Arc;
3
4use anyhow::anyhow;
5use axum::{extract, routing::post, Extension, Json, Router};
6use collections::HashSet;
7use reqwest::StatusCode;
8use serde::{Deserialize, Serialize};
9use stripe::{CheckoutSession, CreateCheckoutSession, CreateCheckoutSessionLineItems, CustomerId};
10
11use crate::{AppState, Error, Result};
12
13pub fn router() -> Router {
14 Router::new().route("/billing/subscriptions", post(create_billing_subscription))
15}
16
17#[derive(Debug, Deserialize)]
18struct CreateBillingSubscriptionBody {
19 github_user_id: i32,
20}
21
22#[derive(Debug, Serialize)]
23struct CreateBillingSubscriptionResponse {
24 checkout_session_url: String,
25}
26
27/// Initiates a Stripe Checkout session for creating a billing subscription.
28async fn create_billing_subscription(
29 Extension(app): Extension<Arc<AppState>>,
30 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
31) -> Result<Json<CreateBillingSubscriptionResponse>> {
32 let user = app
33 .db
34 .get_user_by_github_user_id(body.github_user_id)
35 .await?
36 .ok_or_else(|| anyhow!("user not found"))?;
37
38 let Some((stripe_client, stripe_price_id)) = app
39 .stripe_client
40 .clone()
41 .zip(app.config.stripe_price_id.clone())
42 else {
43 log::error!("failed to retrieve Stripe client or price ID");
44 Err(Error::Http(
45 StatusCode::NOT_IMPLEMENTED,
46 "not supported".into(),
47 ))?
48 };
49
50 let existing_customer_id = {
51 let existing_subscriptions = app.db.get_billing_subscriptions(user.id).await?;
52 let distinct_customer_ids = existing_subscriptions
53 .iter()
54 .map(|subscription| subscription.stripe_customer_id.as_str())
55 .collect::<HashSet<_>>();
56 // Sanity: Make sure we can determine a single Stripe customer ID for the user.
57 if distinct_customer_ids.len() > 1 {
58 Err(anyhow!("user has multiple existing customer IDs"))?;
59 }
60
61 distinct_customer_ids
62 .into_iter()
63 .next()
64 .map(|id| CustomerId::from_str(id).map_err(|err| anyhow!(err)))
65 .transpose()
66 }?;
67
68 let checkout_session = {
69 let mut params = CreateCheckoutSession::new();
70 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
71 params.customer = existing_customer_id;
72 params.client_reference_id = Some(user.github_login.as_str());
73 params.line_items = Some(vec![CreateCheckoutSessionLineItems {
74 price: Some(stripe_price_id.to_string()),
75 quantity: Some(1),
76 ..Default::default()
77 }]);
78 params.success_url = Some("https://zed.dev/billing/success");
79
80 CheckoutSession::create(&stripe_client, params).await?
81 };
82
83 Ok(Json(CreateBillingSubscriptionResponse {
84 checkout_session_url: checkout_session
85 .url
86 .ok_or_else(|| anyhow!("no checkout session URL"))?,
87 }))
88}