1use std::str::FromStr;
2use std::sync::Arc;
3
4use anyhow::{anyhow, Context};
5use axum::{extract, routing::post, Extension, Json, Router};
6use reqwest::StatusCode;
7use serde::{Deserialize, Serialize};
8use stripe::{
9 BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
10 CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
11 CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
12 CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
13 CreateCustomer, Customer, CustomerId,
14};
15
16use crate::db::BillingSubscriptionId;
17use crate::{AppState, Error, Result};
18
19pub fn router() -> Router {
20 Router::new()
21 .route("/billing/subscriptions", post(create_billing_subscription))
22 .route(
23 "/billing/subscriptions/manage",
24 post(manage_billing_subscription),
25 )
26}
27
28#[derive(Debug, Deserialize)]
29struct CreateBillingSubscriptionBody {
30 github_user_id: i32,
31}
32
33#[derive(Debug, Serialize)]
34struct CreateBillingSubscriptionResponse {
35 checkout_session_url: String,
36}
37
38/// Initiates a Stripe Checkout session for creating a billing subscription.
39async fn create_billing_subscription(
40 Extension(app): Extension<Arc<AppState>>,
41 extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
42) -> Result<Json<CreateBillingSubscriptionResponse>> {
43 let user = app
44 .db
45 .get_user_by_github_user_id(body.github_user_id)
46 .await?
47 .ok_or_else(|| anyhow!("user not found"))?;
48
49 let Some((stripe_client, stripe_price_id)) = app
50 .stripe_client
51 .clone()
52 .zip(app.config.stripe_price_id.clone())
53 else {
54 log::error!("failed to retrieve Stripe client or price ID");
55 Err(Error::Http(
56 StatusCode::NOT_IMPLEMENTED,
57 "not supported".into(),
58 ))?
59 };
60
61 let customer_id =
62 if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
63 CustomerId::from_str(&existing_customer.stripe_customer_id)
64 .context("failed to parse customer ID")?
65 } else {
66 let customer = Customer::create(
67 &stripe_client,
68 CreateCustomer {
69 email: user.email_address.as_deref(),
70 ..Default::default()
71 },
72 )
73 .await?;
74
75 customer.id
76 };
77
78 let checkout_session = {
79 let mut params = CreateCheckoutSession::new();
80 params.mode = Some(stripe::CheckoutSessionMode::Subscription);
81 params.customer = Some(customer_id);
82 params.client_reference_id = Some(user.github_login.as_str());
83 params.line_items = Some(vec![CreateCheckoutSessionLineItems {
84 price: Some(stripe_price_id.to_string()),
85 quantity: Some(1),
86 ..Default::default()
87 }]);
88 params.success_url = Some("https://zed.dev/billing/success");
89
90 CheckoutSession::create(&stripe_client, params).await?
91 };
92
93 Ok(Json(CreateBillingSubscriptionResponse {
94 checkout_session_url: checkout_session
95 .url
96 .ok_or_else(|| anyhow!("no checkout session URL"))?,
97 }))
98}
99
100#[derive(Debug, Deserialize)]
101#[serde(rename_all = "snake_case")]
102enum ManageSubscriptionIntent {
103 /// The user intends to cancel their subscription.
104 Cancel,
105}
106
107#[derive(Debug, Deserialize)]
108struct ManageBillingSubscriptionBody {
109 github_user_id: i32,
110 intent: ManageSubscriptionIntent,
111 /// The ID of the subscription to manage.
112 ///
113 /// If not provided, we will try to use the active subscription (if there is only one).
114 subscription_id: Option<BillingSubscriptionId>,
115}
116
117#[derive(Debug, Serialize)]
118struct ManageBillingSubscriptionResponse {
119 billing_portal_session_url: String,
120}
121
122/// Initiates a Stripe customer portal session for managing a billing subscription.
123async fn manage_billing_subscription(
124 Extension(app): Extension<Arc<AppState>>,
125 extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
126) -> Result<Json<ManageBillingSubscriptionResponse>> {
127 let user = app
128 .db
129 .get_user_by_github_user_id(body.github_user_id)
130 .await?
131 .ok_or_else(|| anyhow!("user not found"))?;
132
133 let Some(stripe_client) = app.stripe_client.clone() else {
134 log::error!("failed to retrieve Stripe client");
135 Err(Error::Http(
136 StatusCode::NOT_IMPLEMENTED,
137 "not supported".into(),
138 ))?
139 };
140
141 let customer = app
142 .db
143 .get_billing_customer_by_user_id(user.id)
144 .await?
145 .ok_or_else(|| anyhow!("billing customer not found"))?;
146 let customer_id = CustomerId::from_str(&customer.stripe_customer_id)
147 .context("failed to parse customer ID")?;
148
149 let subscription = if let Some(subscription_id) = body.subscription_id {
150 app.db
151 .get_billing_subscription_by_id(subscription_id)
152 .await?
153 .ok_or_else(|| anyhow!("subscription not found"))?
154 } else {
155 // If no subscription ID was provided, try to find the only active subscription ID.
156 let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
157 if subscriptions.len() > 1 {
158 Err(anyhow!("user has multiple active subscriptions"))?;
159 }
160
161 subscriptions
162 .into_iter()
163 .next()
164 .ok_or_else(|| anyhow!("user has no active subscriptions"))?
165 };
166
167 let flow = match body.intent {
168 ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
169 type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
170 after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
171 type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
172 redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
173 return_url: "https://zed.dev/billing".into(),
174 }),
175 ..Default::default()
176 }),
177 subscription_cancel: Some(
178 stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
179 subscription: subscription.stripe_subscription_id,
180 retention: None,
181 },
182 ),
183 ..Default::default()
184 },
185 };
186
187 let mut params = CreateBillingPortalSession::new(customer_id);
188 params.flow_data = Some(flow);
189 params.return_url = Some("https://zed.dev/billing");
190
191 let session = BillingPortalSession::create(&stripe_client, params).await?;
192
193 Ok(Json(ManageBillingSubscriptionResponse {
194 billing_portal_session_url: session.url,
195 }))
196}