1use std::str::FromStr;
2use std::sync::Arc;
3
4use anyhow::{anyhow, Context};
5use axum::{extract, routing::post, Extension, Json, Router};
6use collections::HashSet;
7use reqwest::StatusCode;
8use serde::{Deserialize, Serialize};
9use stripe::{
10 BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
11 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
12 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
13 CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
14 CustomerId,
15};
16
17use crate::db::BillingSubscriptionId;
18use crate::{AppState, Error, Result};
19
20pub fn router() -> Router {
21 Router::new()
22 .route("/billing/subscriptions", post(create_billing_subscription))
23 .route(
24 "/billing/subscriptions/manage",
25 post(manage_billing_subscription),
26 )
27}
28
29#[derive(Debug, Deserialize)]
30struct CreateBillingSubscriptionBody {
31 github_user_id: i32,
32}
33
34#[derive(Debug, Serialize)]
35struct CreateBillingSubscriptionResponse {
36 checkout_session_url: String,
37}
38
39/// Initiates a Stripe Checkout session for creating a billing subscription.
40async fn create_billing_subscription(
41 Extension(app): Extension<Arc<AppState>>,
42 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
43) -> Result<Json<CreateBillingSubscriptionResponse>> {
44 let user = app
45 .db
46 .get_user_by_github_user_id(body.github_user_id)
47 .await?
48 .ok_or_else(|| anyhow!("user not found"))?;
49
50 let Some((stripe_client, stripe_price_id)) = app
51 .stripe_client
52 .clone()
53 .zip(app.config.stripe_price_id.clone())
54 else {
55 log::error!("failed to retrieve Stripe client or price ID");
56 Err(Error::Http(
57 StatusCode::NOT_IMPLEMENTED,
58 "not supported".into(),
59 ))?
60 };
61
62 let existing_customer_id = {
63 let existing_subscriptions = app.db.get_billing_subscriptions(user.id).await?;
64 let distinct_customer_ids = existing_subscriptions
65 .iter()
66 .map(|subscription| subscription.stripe_customer_id.as_str())
67 .collect::<HashSet<_>>();
68 // Sanity: Make sure we can determine a single Stripe customer ID for the user.
69 if distinct_customer_ids.len() > 1 {
70 Err(anyhow!("user has multiple existing customer IDs"))?;
71 }
72
73 distinct_customer_ids
74 .into_iter()
75 .next()
76 .map(|id| CustomerId::from_str(id).context("failed to parse customer ID"))
77 .transpose()
78 }?;
79
80 let checkout_session = {
81 let mut params = CreateCheckoutSession::new();
82 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
83 params.customer = existing_customer_id;
84 params.client_reference_id = Some(user.github_login.as_str());
85 params.line_items = Some(vec![CreateCheckoutSessionLineItems {
86 price: Some(stripe_price_id.to_string()),
87 quantity: Some(1),
88 ..Default::default()
89 }]);
90 params.success_url = Some("https://zed.dev/billing/success");
91
92 CheckoutSession::create(&stripe_client, params).await?
93 };
94
95 Ok(Json(CreateBillingSubscriptionResponse {
96 checkout_session_url: checkout_session
97 .url
98 .ok_or_else(|| anyhow!("no checkout session URL"))?,
99 }))
100}
101
102#[derive(Debug, Deserialize)]
103#[serde(rename_all = "snake_case")]
104enum ManageSubscriptionIntent {
105 /// The user intends to cancel their subscription.
106 Cancel,
107}
108
109#[derive(Debug, Deserialize)]
110struct ManageBillingSubscriptionBody {
111 github_user_id: i32,
112 intent: ManageSubscriptionIntent,
113 /// The ID of the subscription to manage.
114 ///
115 /// If not provided, we will try to use the active subscription (if there is only one).
116 subscription_id: Option<BillingSubscriptionId>,
117}
118
119#[derive(Debug, Serialize)]
120struct ManageBillingSubscriptionResponse {
121 billing_portal_session_url: String,
122}
123
124/// Initiates a Stripe customer portal session for managing a billing subscription.
125async fn manage_billing_subscription(
126 Extension(app): Extension<Arc<AppState>>,
127 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
128) -> Result<Json<ManageBillingSubscriptionResponse>> {
129 let user = app
130 .db
131 .get_user_by_github_user_id(body.github_user_id)
132 .await?
133 .ok_or_else(|| anyhow!("user not found"))?;
134
135 let Some(stripe_client) = app.stripe_client.clone() else {
136 log::error!("failed to retrieve Stripe client");
137 Err(Error::Http(
138 StatusCode::NOT_IMPLEMENTED,
139 "not supported".into(),
140 ))?
141 };
142
143 let subscription = if let Some(subscription_id) = body.subscription_id {
144 app.db
145 .get_billing_subscription_by_id(subscription_id)
146 .await?
147 .ok_or_else(|| anyhow!("subscription not found"))?
148 } else {
149 // If no subscription ID was provided, try to find the only active subscription ID.
150 let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
151 if subscriptions.len() > 1 {
152 Err(anyhow!("user has multiple active subscriptions"))?;
153 }
154
155 subscriptions
156 .into_iter()
157 .next()
158 .ok_or_else(|| anyhow!("user has no active subscriptions"))?
159 };
160
161 let customer_id = CustomerId::from_str(&subscription.stripe_customer_id)
162 .context("failed to parse customer ID")?;
163
164 let flow = match body.intent {
165 ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
166 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
167 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
168 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
169 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
170 return_url: "https://zed.dev/billing".into(),
171 }),
172 ..Default::default()
173 }),
174 subscription_cancel: Some(
175 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
176 subscription: subscription.stripe_subscription_id,
177 retention: None,
178 },
179 ),
180 ..Default::default()
181 },
182 };
183
184 let mut params = CreateBillingPortalSession::new(customer_id);
185 params.flow_data = Some(flow);
186 params.return_url = Some("https://zed.dev/billing");
187
188 let session = BillingPortalSession::create(&stripe_client, params).await?;
189
190 Ok(Json(ManageBillingSubscriptionResponse {
191 billing_portal_session_url: session.url,
192 }))
193}